From ba55a8d59834033e3c8daa904c6c9ebb4dcb884b Mon Sep 17 00:00:00 2001 From: Qianli Ma Date: Wed, 9 Apr 2025 20:16:31 -0700 Subject: [PATCH 01/10] feat: add post-training and custom-training support --- INSTALL.md | 23 +- cosmos_transfer1/POST_TRAINING.md | 3 - cosmos_transfer1/checkpointer/base.py | 127 ++ .../checkpointer/ddp_checkpointer.py | 437 +++++++ .../checkpointer/ema_fsdp_checkpointer.py | 54 + .../checkpointer/fsdp_checkpointer.py | 392 ++++++ .../checkpointer/fsdp_optim_fix.py | 351 ++++++ .../checkpointer/multi_rank_checkpointer.py | 236 ++++ .../checkpointer/safe_broadcast.py | 95 ++ .../checkpointer/tp_checkpointer.py | 42 + cosmos_transfer1/diffusion/conditioner.py | 2 + .../diffusion/config/base/conditioner.py | 31 + .../diffusion/config/base/model.py | 56 + cosmos_transfer1/diffusion/config/base/net.py | 4 +- cosmos_transfer1/diffusion/config/config.py | 3 + .../diffusion/config/config_train.py | 89 ++ .../cosmos-1-diffusion-control2world.py | 2 +- cosmos_transfer1/diffusion/config/registry.py | 4 +- .../diffusion/config/training/__init__.py | 0 .../diffusion/config/training/callbacks.py | 29 + .../diffusion/config/training/checkpoint.py | 27 + .../diffusion/config/training/ema.py | 27 + .../experiment/ctrl_7b_tp_121frames.py | 211 ++++ .../diffusion/config/training/optim.py | 40 + .../diffusion/config/training/registry.py | 78 ++ .../config/training/registry_extra.py | 44 + .../diffusion/config/transfer/blurs.py | 24 + .../diffusion/config/transfer/conditioner.py | 8 + .../diffusion/config/transfer/registry.py | 2 +- .../diffusion/functional/batch_ops.py | 61 + .../diffusion/model/model_ctrl.py | 286 ++++- cosmos_transfer1/diffusion/model/model_v2w.py | 101 +- cosmos_transfer1/diffusion/module/blocks.py | 49 + cosmos_transfer1/diffusion/module/parallel.py | 2 +- .../diffusion/module/position_embedding.py | 435 ++++++- .../diffusion/networks/general_dit.py | 592 +++++++-- .../networks/general_dit_ctrl_enc.py | 163 ++- .../networks/general_dit_video_conditioned.py | 133 +- .../diffusion/training/callbacks/every_n.py | 86 ++ .../diffusion/training/callbacks/grad_clip.py | 101 ++ .../training/callbacks/iter_speed.py | 82 ++ .../training/callbacks/low_precision.py | 41 + .../datasets/data_sources/item_dataset.py | 22 + .../training/datasets/dataset_utils.py | 311 +++++ .../training/datasets/dataset_video.py | 206 ++++ .../diffusion/training/functional/loss.py | 135 +++ .../training/functional/lr_scheduler.py | 178 +++ .../diffusion/training/modules/edm_sde.py | 43 + .../diffusion/training/tensor_parallel.py | 102 ++ cosmos_transfer1/diffusion/training/train.py | 130 ++ .../training/utils/optim_instantiate.py | 83 ++ cosmos_transfer1/utils/callback.py | 457 +++++++ cosmos_transfer1/utils/checkpointer.py | 237 ++++ cosmos_transfer1/utils/config.py | 163 ++- cosmos_transfer1/utils/easy_io/__init__.py | 14 + .../utils/easy_io/backends/__init__.py | 13 + .../utils/easy_io/backends/base_backend.py | 60 + .../utils/easy_io/backends/http_backend.py | 91 ++ .../utils/easy_io/backends/local_backend.py | 550 +++++++++ .../utils/easy_io/backends/registry_utils.py | 127 ++ cosmos_transfer1/utils/easy_io/easy_io.py | 1066 +++++++++++++++++ cosmos_transfer1/utils/easy_io/file_client.py | 450 +++++++ .../utils/easy_io/handlers/__init__.py | 29 + .../utils/easy_io/handlers/base.py | 44 + .../utils/easy_io/handlers/csv_handler.py | 42 + .../utils/easy_io/handlers/gzip_handler.py | 33 + .../easy_io/handlers/imageio_video_handler.py | 91 ++ .../utils/easy_io/handlers/json_handler.py | 49 + .../utils/easy_io/handlers/jsonl_handler.py | 80 ++ .../utils/easy_io/handlers/np_handler.py | 89 ++ .../utils/easy_io/handlers/pandas_handler.py | 31 + .../utils/easy_io/handlers/pickle_handler.py | 42 + .../utils/easy_io/handlers/pil_handler.py | 96 ++ .../utils/easy_io/handlers/registry_utils.py | 80 ++ .../utils/easy_io/handlers/tarfile_handler.py | 39 + .../utils/easy_io/handlers/torch_handler.py | 34 + .../easy_io/handlers/torchjit_handler.py | 34 + .../utils/easy_io/handlers/txt_handler.py | 34 + .../utils/easy_io/handlers/yaml_handler.py | 38 + cosmos_transfer1/utils/ema.py | 327 +++++ cosmos_transfer1/utils/fused_adam.py | 398 ++++++ cosmos_transfer1/utils/lazy_config/lazy.py | 154 +++ cosmos_transfer1/utils/misc.py | 15 + cosmos_transfer1/utils/model.py | 136 +++ .../utils/parallel_state_helper.py | 24 + cosmos_transfer1/utils/trainer.py | 279 +++++ .../post-training_cosmos_transfer_7b_edge.md | 211 ++++ scripts/get_t5_embeddings.py | 126 ++ scripts/test_environment.py | 17 + 89 files changed, 11412 insertions(+), 171 deletions(-) delete mode 100644 cosmos_transfer1/POST_TRAINING.md create mode 100644 cosmos_transfer1/checkpointer/base.py create mode 100644 cosmos_transfer1/checkpointer/ddp_checkpointer.py create mode 100644 cosmos_transfer1/checkpointer/ema_fsdp_checkpointer.py create mode 100644 cosmos_transfer1/checkpointer/fsdp_checkpointer.py create mode 100644 cosmos_transfer1/checkpointer/fsdp_optim_fix.py create mode 100644 cosmos_transfer1/checkpointer/multi_rank_checkpointer.py create mode 100644 cosmos_transfer1/checkpointer/safe_broadcast.py create mode 100644 cosmos_transfer1/checkpointer/tp_checkpointer.py create mode 100644 cosmos_transfer1/diffusion/config/config_train.py create mode 100644 cosmos_transfer1/diffusion/config/training/__init__.py create mode 100644 cosmos_transfer1/diffusion/config/training/callbacks.py create mode 100644 cosmos_transfer1/diffusion/config/training/checkpoint.py create mode 100644 cosmos_transfer1/diffusion/config/training/ema.py create mode 100644 cosmos_transfer1/diffusion/config/training/experiment/ctrl_7b_tp_121frames.py create mode 100644 cosmos_transfer1/diffusion/config/training/optim.py create mode 100644 cosmos_transfer1/diffusion/config/training/registry.py create mode 100644 cosmos_transfer1/diffusion/config/training/registry_extra.py create mode 100644 cosmos_transfer1/diffusion/functional/batch_ops.py create mode 100644 cosmos_transfer1/diffusion/training/callbacks/every_n.py create mode 100644 cosmos_transfer1/diffusion/training/callbacks/grad_clip.py create mode 100644 cosmos_transfer1/diffusion/training/callbacks/iter_speed.py create mode 100644 cosmos_transfer1/diffusion/training/callbacks/low_precision.py create mode 100644 cosmos_transfer1/diffusion/training/datasets/data_sources/item_dataset.py create mode 100644 cosmos_transfer1/diffusion/training/datasets/dataset_utils.py create mode 100644 cosmos_transfer1/diffusion/training/datasets/dataset_video.py create mode 100644 cosmos_transfer1/diffusion/training/functional/loss.py create mode 100644 cosmos_transfer1/diffusion/training/functional/lr_scheduler.py create mode 100644 cosmos_transfer1/diffusion/training/modules/edm_sde.py create mode 100644 cosmos_transfer1/diffusion/training/tensor_parallel.py create mode 100644 cosmos_transfer1/diffusion/training/train.py create mode 100644 cosmos_transfer1/diffusion/training/utils/optim_instantiate.py create mode 100644 cosmos_transfer1/utils/callback.py create mode 100644 cosmos_transfer1/utils/checkpointer.py create mode 100644 cosmos_transfer1/utils/easy_io/__init__.py create mode 100644 cosmos_transfer1/utils/easy_io/backends/__init__.py create mode 100644 cosmos_transfer1/utils/easy_io/backends/base_backend.py create mode 100644 cosmos_transfer1/utils/easy_io/backends/http_backend.py create mode 100644 cosmos_transfer1/utils/easy_io/backends/local_backend.py create mode 100644 cosmos_transfer1/utils/easy_io/backends/registry_utils.py create mode 100644 cosmos_transfer1/utils/easy_io/easy_io.py create mode 100644 cosmos_transfer1/utils/easy_io/file_client.py create mode 100644 cosmos_transfer1/utils/easy_io/handlers/__init__.py create mode 100644 cosmos_transfer1/utils/easy_io/handlers/base.py create mode 100644 cosmos_transfer1/utils/easy_io/handlers/csv_handler.py create mode 100644 cosmos_transfer1/utils/easy_io/handlers/gzip_handler.py create mode 100644 cosmos_transfer1/utils/easy_io/handlers/imageio_video_handler.py create mode 100644 cosmos_transfer1/utils/easy_io/handlers/json_handler.py create mode 100644 cosmos_transfer1/utils/easy_io/handlers/jsonl_handler.py create mode 100644 cosmos_transfer1/utils/easy_io/handlers/np_handler.py create mode 100644 cosmos_transfer1/utils/easy_io/handlers/pandas_handler.py create mode 100644 cosmos_transfer1/utils/easy_io/handlers/pickle_handler.py create mode 100644 cosmos_transfer1/utils/easy_io/handlers/pil_handler.py create mode 100644 cosmos_transfer1/utils/easy_io/handlers/registry_utils.py create mode 100644 cosmos_transfer1/utils/easy_io/handlers/tarfile_handler.py create mode 100644 cosmos_transfer1/utils/easy_io/handlers/torch_handler.py create mode 100644 cosmos_transfer1/utils/easy_io/handlers/torchjit_handler.py create mode 100644 cosmos_transfer1/utils/easy_io/handlers/txt_handler.py create mode 100644 cosmos_transfer1/utils/easy_io/handlers/yaml_handler.py create mode 100644 cosmos_transfer1/utils/ema.py create mode 100644 cosmos_transfer1/utils/fused_adam.py create mode 100644 cosmos_transfer1/utils/model.py create mode 100644 cosmos_transfer1/utils/parallel_state_helper.py create mode 100644 cosmos_transfer1/utils/trainer.py create mode 100644 examples/post-training_cosmos_transfer_7b_edge.md create mode 100644 scripts/get_t5_embeddings.py diff --git a/INSTALL.md b/INSTALL.md index cece167d..92061bd8 100644 --- a/INSTALL.md +++ b/INSTALL.md @@ -34,4 +34,25 @@ CUDA_HOME=$CONDA_PREFIX PYTHONPATH=$(pwd) python scripts/test_environment.py ### Post-training -Coming soon! +The below commands creates the `cosmos-transfer` conda environment and installs the dependencies for post-training. This is the same as required for inference but with an additional package `apex` for training with bfloat16. +```bash +# Create the cosmos-transfer1 conda environment. +conda env create --file cosmos-transfer1.yaml +# Activate the cosmos-transfer1 conda environment. +conda activate cosmos-transfer1 +# Install the dependencies. +pip install -r requirements.txt +# Patch Transformer engine linking issues in conda environments. +ln -sf $CONDA_PREFIX/lib/python3.10/site-packages/nvidia/*/include/* $CONDA_PREFIX/include/ +ln -sf $CONDA_PREFIX/lib/python3.10/site-packages/nvidia/*/include/* $CONDA_PREFIX/include/python3.10 +# Install Transformer engine. +pip install transformer-engine[pytorch]==1.12.0 +# Install Apex for full training with bfloat16. +git clone https://github.com/NVIDIA/apex +CUDA_HOME=$CONDA_PREFIX pip install -v --disable-pip-version-check --no-cache-dir --no-build-isolation --config-settings "--build-option=--cpp_ext" --config-settings "--build-option=--cuda_ext" ./apex +``` + +You can test the environment setup for post-training with +```bash +CUDA_HOME=$CONDA_PREFIX PYTHONPATH=$(pwd) python scripts/test_environment.py --training +``` diff --git a/cosmos_transfer1/POST_TRAINING.md b/cosmos_transfer1/POST_TRAINING.md deleted file mode 100644 index e0106a4e..00000000 --- a/cosmos_transfer1/POST_TRAINING.md +++ /dev/null @@ -1,3 +0,0 @@ -# Cosmos-Transfer1 Post-training - -Cosmos-Transfer1 post-training is coming soon! diff --git a/cosmos_transfer1/checkpointer/base.py b/cosmos_transfer1/checkpointer/base.py new file mode 100644 index 00000000..bd4490c6 --- /dev/null +++ b/cosmos_transfer1/checkpointer/base.py @@ -0,0 +1,127 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +from abc import ABC, abstractmethod +from typing import Optional + +import torch + +from cosmos_transfer1.utils import callback +from cosmos_transfer1.utils.config import CheckpointConfig, JobConfig +from cosmos_transfer1.utils.easy_io import easy_io +from cosmos_transfer1.utils.model import Model + + +class AbstractCheckpointer(ABC): + """The checkpointer class. Supports checkpoint saving/loading to local disk.""" + + def __init__(self, config_checkpoint: CheckpointConfig, config_job: JobConfig, callbacks: callback.CallBackGroup): + """Constructor of the checkpointer. + + Args: + config_checkpoint (CheckpointConfig): The config object for the checkpointer. + """ + self.config_checkpoint = config_checkpoint + # Set the callback functions. + self.callbacks = callbacks + + # Set checkpoint directories for local paths + self._local_dirname = os.path.join(config_job.path_local, "checkpoints") + + self.strict_resume = config_checkpoint.strict_resume + self.load_path = config_checkpoint.load_path or None + self.load_training_state = config_checkpoint.load_training_state + self.only_load_scheduler_state = config_checkpoint.only_load_scheduler_state + self.save_thread = None + self.verbose = config_checkpoint.verbose + self.keys_not_to_resume = config_checkpoint.keys_not_to_resume + self.broadcast_via_filesystem = config_checkpoint.broadcast_via_filesystem + + @abstractmethod + def save( + self, + model: Model, + optimizer: torch.optim.Optimizer, + scheduler: torch.optim.lr_scheduler.LRScheduler, + grad_scaler: torch.amp.GradScaler, + iteration: int, + ) -> None: + pass + + @abstractmethod + def load( + self, + model: Model, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[torch.optim.lr_scheduler.LRScheduler] = None, + grad_scaler: Optional[torch.amp.GradScaler] = None, + ) -> int: + pass + + @property + def save_bucket(self): + """Get the bucket name for saving checkpoints.""" + return None + + @property + def load_bucket(self): + """Get the bucket name for loading checkpoints.""" + return None + + @property + def save_dirname(self): + return self._local_dirname + + @property + def load_dirname(self): + return self._local_dirname + + def finalize(self) -> None: + """Finalize the checkpointer.""" + if self.save_thread: + self.save_thread.join() + + def _read_latest_checkpoint_file(self) -> str | None: + """Get the file name of the latest saved checkpoint. If it doesn't exist, return None. + + Returns: + checkpoint_file (str | None): file name of the latest saved checkpoint. + """ + checkpoint_file = None + checkpoint_path = os.path.join(self.load_dirname, "latest_checkpoint.txt") + if easy_io.exists(checkpoint_path): + checkpoint_file = easy_io.load(checkpoint_path).strip() + + return checkpoint_file + + def _write_latest_checkpoint_file(self, checkpoint_file: str) -> None: + """Track the file name of the latest saved checkpoint. + + Args: + checkpoint_file (str): file name of the latest saved checkpoint. + """ + content = f"{checkpoint_file}\n" + checkpoint_path = os.path.join(self.save_dirname, "latest_checkpoint.txt") + easy_io.dump(content, checkpoint_path) + + def _check_checkpoint_exists(self, checkpoint_path: str) -> None: + """If the file checkpoint_path does not exist, raise an error. + + Args: + checkpoint_path (str): full path to the checkpoint. + """ + if not easy_io.exists(checkpoint_path): + raise FileNotFoundError(f"File not found: {checkpoint_path}") diff --git a/cosmos_transfer1/checkpointer/ddp_checkpointer.py b/cosmos_transfer1/checkpointer/ddp_checkpointer.py new file mode 100644 index 00000000..6bab42c4 --- /dev/null +++ b/cosmos_transfer1/checkpointer/ddp_checkpointer.py @@ -0,0 +1,437 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import os +import threading +from collections import namedtuple +from typing import Any, Dict, Optional, Set, Tuple, Union + +import torch +import torch.distributed +from megatron.core import parallel_state +from torch.distributed import ProcessGroup, get_process_group_ranks + +from cosmos_transfer1.checkpointer.base import AbstractCheckpointer +from cosmos_transfer1.checkpointer.safe_broadcast import broadcast_object +from cosmos_transfer1.utils import distributed, log, misc +from cosmos_transfer1.utils.easy_io import easy_io +from cosmos_transfer1.utils.model import Model + +StateDictItemPath = namedtuple("StateDictItemPath", ["state_dict", "save_path"]) + + +class Checkpointer(AbstractCheckpointer): + """ + Checkpointer for DDP. + Note: This implementation only supports local filesystem. + """ + + KEYS_TO_SAVE = ["model", "optim", "scheduler", "trainer"] + KEYS_TO_POSTFIX = { + "model": "model", + "optim": "optim", + "scheduler": "scheduler", + "trainer": "", + } + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + pp_world_size = parallel_state.get_pipeline_model_parallel_world_size() + ep_world_size = parallel_state.get_expert_model_parallel_world_size() + assert pp_world_size < 2, "Pipeline Parallelism (PP) is not tested yet." + assert ep_world_size < 2, "Expert Parallelism (EP) is not tested yet." + self.mp_world_size = parallel_state.get_model_parallel_group().size() + if self.mp_world_size > 1 and self.__class__ == Checkpointer: + raise NotImplementedError( + "Model Parallelism (MP) is enabled - " + "you should use TensorParallel Checkpointer instead of DDP Checkpointer." + ) + # DDP rank (with context parallelism considered) + self.rank_dp_w_cp = parallel_state.get_data_parallel_rank(with_context_parallel=True) + # Context parallelism rank + self.cp_rank = parallel_state.get_context_parallel_rank() + # Model parallelism rank (including Tensor+Pipeline+Expert Parallelisms) + self.mp_rank = parallel_state.get_model_parallel_group().rank() + # self.mp_rank = parallel_state.get_model_parallel_group(with_expert_parallel=ep_world_size > 1).rank() + if self.broadcast_via_filesystem: + log.info("Broadcasting checkpoint data via the local filesystem.") + if not self.strict_resume: + log.warning("Strict resume mode is off. Some model parameters may not be loaded.") + + # collect ranks of all model parallel groups + all_ranks = [None for _ in range(distributed.get_world_size())] + torch.distributed.all_gather_object( + all_ranks, get_process_group_ranks(parallel_state.get_model_parallel_group()) + ) + all_ranks = list(set(tuple(rank) if isinstance(rank, list) else rank for rank in all_ranks)) + for ranks in all_ranks: + group = torch.distributed.new_group(list(ranks), backend="gloo") + if distributed.get_rank() in ranks: + self.mp_gloo_pg = group + + self.print("Checkpointer Initialized.") + + def print(self, message: str): + """ + Print message to the console. Include the parallelism rank information when verbose is set to True. + """ + if self.verbose: + log.info( + f"[Parallelism Rank: DP-{self.rank_dp_w_cp}, TP-{self.mp_rank}, CP-{self.cp_rank}]: {message}", + rank0_only=False, + ) + else: + log.info(message, rank0_only=True) + + def add_type_postfix_to_checkpoint_path(self, key: str, checkpoint_path: str, model: Model) -> str: + del model + assert key in self.KEYS_TO_SAVE + post_fix = self.KEYS_TO_POSTFIX[key] + + if post_fix: + _ckpt_path = checkpoint_path.replace(".pt", f"_{post_fix}.pt") + else: + _ckpt_path = checkpoint_path + return _ckpt_path + + def save( + self, + model: Model, + optimizer: torch.optim.Optimizer, + scheduler: torch.optim.lr_scheduler.LRScheduler, + grad_scaler: torch.amp.GradScaler, + iteration: int, + ) -> None: + """Save network weights, optimizer parameters, scheduler parameters to a checkpoint. + + Args: + model (Model): The PyTorch model. + optimizer (torch.optim.Optimizer): The model optimizer. + scheduler (torch.optim.lr_scheduler.LRScheduler): The optimization scheduler. + grad_scaler (torch.amp.GradScaler): The gradient scaler (for mixed precision training). + iteration (int): Current iteration number. + """ + self.callbacks.on_save_checkpoint_start(model, iteration) + + checkpoint_file = self.format_checkpoint_filename(model, iteration) + state_dict = self.generate_save_state_dict(model, optimizer, scheduler, grad_scaler, iteration) + state_dict = self._map_state_dict_path_during_save(state_dict, checkpoint_file, model) + if state_dict: + # Wait for previous saver thread to end. + if self.save_thread: + self.save_thread.join() + # Run the checkpoint saver in a separate thread. + self.save_thread = threading.Thread( + target=self._save_worker, + daemon=False, + args=(state_dict, checkpoint_file, distributed.get_rank()), + ) + self.save_thread.start() + + # Note: Checkpoints are saved on a separate thread and this callback is not accurate. + # Please check logs from on_save_checkpoint_success() for better accuracy + self.callbacks.on_save_checkpoint_end(model=None, iteration=iteration) + + def _map_state_dict_path_during_save(self, state_dict, checkpoint_file, model) -> dict[str, StateDictItemPath]: + new_dict = {} + for key, _state_dict in state_dict.items(): + _ckpt_path = self.add_type_postfix_to_checkpoint_path(key, checkpoint_file, model) + checkpoint_path = os.path.join(self.save_dirname, _ckpt_path) + new_dict[key] = StateDictItemPath(_state_dict, checkpoint_path) + return new_dict + + @misc.timer("checkpoint saving") + def _save_worker(self, state_dict: dict[str, StateDictItemPath], checkpoint_file: str, rank: int = 0) -> None: + """Worker to save checkpoint to disk, spawned with a child thread (in parallel with the training). + + Args: + state_dict (dict[str, StateDictItemPath]): The state dict of the model/optimizer/scheduler. + checkpoint_file (str): The file name of the model checkpoint. + rank (int): GPU device (default: 0). + """ + try: + for key, item in state_dict.items(): + self.print(f"Saving {key} to {item.save_path}") + try: + easy_io.dump( + item.state_dict, + item.save_path, + fast_backend=True, # optional for fast backend, cpu heavy + ) + self.print(f"Saved {key} to {item.save_path}") + except Exception as e: + self.print(f"Failed to save {key} to {item.save_path}: {str(e)}") + raise # Re-raise the exception after logging + + # Synchronize only rank 0 of each model parallel group + if self.mp_world_size > 1: + torch.distributed.barrier(group=self.mp_gloo_pg) + + # Only rank 0 of MP group and rank 0 of DP with CP updates latest_checkpoint.txt + if self.mp_rank == 0 and self.rank_dp_w_cp == 0: + self._write_latest_checkpoint_file(checkpoint_file) + + if distributed.get_rank() == 0: # only rank 0 saves trained_data_record + if "trained_data_record" in state_dict["model"].state_dict: + self._write_trained_data_record( + checkpoint_file, state_dict["model"].state_dict["trained_data_record"] + ) + + iteration = int(checkpoint_file.replace("iter_", "").replace(".pt", "")) + self.callbacks.on_save_checkpoint_success(iteration=iteration) + except Exception as e: # noqa: BLE001 + log.exception(f"Checkpoint failed to upload: {e}", rank0_only=not self.verbose) + + def format_checkpoint_filename(self, model: Model, iteration: int) -> str: + """Generate the checkpoint file name. + + Args: + iteration (int): The current iteration number. + + Returns: + checkpoint_file (str): The checkpoint file name. + """ + del self, model + return f"iter_{iteration:09}.pt" + + @misc.timer("generate saving state dict") + def generate_save_state_dict( + self, + model: Model, + optimizer: torch.optim.Optimizer, + scheduler: torch.optim.lr_scheduler.LRScheduler, + grad_scaler: torch.amp.GradScaler, + iteration: int, + ) -> Optional[Dict[str, Any]]: + state_dict = {} + + if self.rank_dp_w_cp == 0: + trainer_state = dict( + grad_scaler=grad_scaler.state_dict(), + iteration=iteration, + ) + model_state = model.state_dict() + optim_state = optimizer.state_dict() + scheduler_state = scheduler.state_dict() + self.callbacks.on_save_checkpoint(model, state_dict=trainer_state) + + trainer_state, model_state, optim_state, scheduler_state = misc.to( + [trainer_state, model_state, optim_state, scheduler_state], device="cpu" + ) + + state_dict = { + "model": model_state, + "optim": optim_state, + "scheduler": scheduler_state, + } + if distributed.get_rank() == 0: # only rank 0 saves trainer state + state_dict["trainer"] = trainer_state + return state_dict + return state_dict + + def load_broadcast_state_dict(self, checkpoint_path: str, model: Model, resume_keys: Set) -> dict[str, Any]: + """ + Load state_dict and broadcast. + + The main steps are: + 1. Download TP-rank-specific checkpoints for every GPU of DDP-rank 0 and CP-rank 0. + 2. Each rank loads its corresponding checkpoint from the local cache or receives it via broadcast. + + This approach ensures that each MP rank loads its specific part of the model, which is + crucial for Model Parallelism where different parts of the model are distributed across + multiple GPUs. + + When using Model Parallelism (e.g., Tensor Parallelism), the `broadcast_via_filesystem` option can + be set to True. This allows each rank to load its specific checkpoint from the local filesystem + instead of receiving it via network broadcast, which could be more efficient in some cases. + + For standard DDP without TP, `broadcast_via_filesystem` should remain False (default). + + Args: + checkpoint_path (str): The base path of the checkpoint. + model (Model): The model being loaded. + resume_keys (Set): Set of keys to resume from the checkpoint. + + Returns: + dict[str, Any]: A dictionary containing the loaded state for each resumed key. + """ + state_dict = {} + sorted_resume_keys = sorted(resume_keys) + # Step 1: Download TP-rank-specific checkpoints for every GPU of DDP-rank 0 and CP-rank 0. + if self.rank_dp_w_cp == 0: + for key in sorted_resume_keys: + _ckpt_path = self.add_type_postfix_to_checkpoint_path(key, checkpoint_path, model) + local_cache_path = os.path.join(self.load_dirname, os.path.basename(_ckpt_path)) + if os.path.exists(local_cache_path): + # If the local checkpoint exists, we can directly load it + self.print(f"Checkpoint is already in local cache: {local_cache_path}. Loading...") + _state_dict = easy_io.load(local_cache_path, fast_backend=True) + else: + _state_dict = easy_io.load(_ckpt_path, fast_backend=True) + self.print(f"Downloading checkpoint from: {_ckpt_path}") + if self.broadcast_via_filesystem: + # Save the checkpoint to the local filesystem + easy_io.dump(_state_dict, local_cache_path, fast_backend=True) + state_dict[key] = _state_dict + # Ensure all ranks wait for the download to complete + distributed.barrier() + + # Step 2: Broadcast checkpoint data + log.info( + "Start broadcasting checkpoint from the source rank to all other ranks in the same DDP group.", + rank0_only=True, + ) + for key in sorted_resume_keys: + if self.broadcast_via_filesystem: + # Load the checkpoint from the local filesystem for other ranks + if self.rank_dp_w_cp != 0: + _ckpt_path = self.add_type_postfix_to_checkpoint_path(key, checkpoint_path, model) + local_cache_path = os.path.join(self.load_dirname, os.path.basename(_ckpt_path)) + self.print(f"Loading checkpoint from: {local_cache_path}") + state_dict[key] = easy_io.load(local_cache_path, fast_backend=True) + else: + # Broadcast the checkpoint to all GPUs of the current DDP rank + group: ProcessGroup = parallel_state.get_data_parallel_group(with_context_parallel=True) + min_rank = min(get_process_group_ranks(group)) + + _state_dict = broadcast_object( + state_dict[key] if self.rank_dp_w_cp == 0 else None, + min_rank, + group=group, + device=torch.device(torch.cuda.current_device()), + ) + if self.rank_dp_w_cp == 0: + self.print(f'Broadcasted checkpoint["{key}"] to all other ranks in the same DDP group.') + else: + state_dict[key] = _state_dict + self.print(f'Received checkpoint["{key}"] from source rank {min_rank}.') + + return state_dict + + def keys_to_resume_during_load(self) -> Tuple[Set, Union[str, None]]: + latest_checkpoint_file = self._read_latest_checkpoint_file() + + resume_keys = [] + + if latest_checkpoint_file is not None: + # 1. Resume training from latest_checkpoint.txt under the same name. + checkpoint_path = os.path.join(self.load_dirname, latest_checkpoint_file) + resume_keys.extend(self.KEYS_TO_SAVE) + else: + if self.load_path: + # 2. Load the module weights specified by config_checkpoint.path. + checkpoint_path = self.load_path + if self.load_training_state: + resume_keys.extend(self.KEYS_TO_SAVE) + else: + resume_keys.append("model") + if self.only_load_scheduler_state: + resume_keys.append("scheduler") + else: + checkpoint_path = None + if len(self.keys_not_to_resume) > 0: + for key in self.keys_not_to_resume: + assert key in self.KEYS_TO_SAVE, f"Invalid key to resume: {key} not in {self.KEYS_TO_SAVE}" + resume_keys = [key for key in resume_keys if key not in self.keys_not_to_resume] + return set(resume_keys), checkpoint_path + + @misc.timer("checkpoint loading") + def load( + self, + model: Model, + optimizer: torch.optim.Optimizer | None = None, + scheduler: torch.optim.lr_scheduler.LRScheduler | None = None, + grad_scaler: torch.amp.GradScaler | None = None, + ) -> int: + """Load network weights and optimizer states from a checkpoint in a single process. + + The priority of the checkpoint loading logic is: + 1. Attempt to resume training if possible by looking for latest_checkpoint.txt under the same name. + 2. If no latest checkpoint were found, it loads the model weights specified by config_checkpoint.path. + - This is typically used for inference mode. + - If config_checkpoint.load_optimizer_state is True, then also load the optimizer and scheduler states. + 3. If none of the above, randomly initialize the model parameters and train from scratch. + + Args: + model (Model): The PyTorch model. + optimizer (torch.optim.Optimizer | None): The model optimizer (default: None). + scheduler (torch.optim.lr_scheduler.LRScheduler | None): The optimization scheduler (default: None). + grad_scaler (torch.amp.GradScaler | None): The gradient scaler (for mixed precision training). + + Returns: + iteration (int): the iteration number to start/resume from. + """ + self.callbacks.on_load_checkpoint_start(model) + + resume_keys, checkpoint_path = self.keys_to_resume_during_load() + + iteration = 0 + + # Load checkpoint. + if checkpoint_path is not None: + self._check_checkpoint_exists(checkpoint_path) + state_dict = self.load_broadcast_state_dict(checkpoint_path, model, set(resume_keys)) + + if "trainer" in state_dict: + trainer_state = state_dict["trainer"] + log.critical(state_dict.keys(), rank0_only=False) + log.critical(trainer_state, rank0_only=False) + log.info("- Loading the gradient scaler...") + grad_scaler.load_state_dict(trainer_state["grad_scaler"]) + self.callbacks.on_load_checkpoint(model, state_dict=trainer_state) + iteration = trainer_state["iteration"] + if "optim" in state_dict: + assert optimizer + optimizer_state = state_dict["optim"] + log.info("- Loading the optimizer...") + optimizer.load_state_dict(optimizer_state) + if "scheduler" in state_dict: + assert scheduler + scheduler_state = state_dict["scheduler"] + log.info("- Loading the scheduler...") + scheduler.load_state_dict(scheduler_state) + scheduler.last_epoch = iteration + if "model" in state_dict: + model_state = state_dict["model"] + log.info("- Loading the model...") + # model.load_state_dict(model_state) + if self.strict_resume: + log.info("\t Strict resume mode is on.") + else: + log.info("\t Strict resume mode is off.") + model_load_info = model.load_state_dict(model_state, strict=self.strict_resume) + log.info(f"\t {model_load_info}") + self.print(f"Loaded checkpoint from {checkpoint_path} in iteration {iteration}") + else: + log.info("Training from scratch.") + torch.cuda.empty_cache() + + self.callbacks.on_load_checkpoint_end(model) + + return iteration + + def _write_trained_data_record(self, checkpoint_file: str, trained_data_record: dict[str, int]) -> None: + """Write json file to save number of seen samples and number of iterations. + + Args: + checkpoint_file (str): iteration number for the saved checkpoint + trained_data_record (dict[str, int]): example {"image": 0, "video": 0, "iteration": 0}. + """ + # filename: iter_xxxxxxxxx_trained_data_record.json + checkpoint_path = os.path.join( + self.save_dirname, f"{checkpoint_file.replace('.pt', '')}_trained_data_record.json" + ) + easy_io.dump(trained_data_record, checkpoint_path) diff --git a/cosmos_transfer1/checkpointer/ema_fsdp_checkpointer.py b/cosmos_transfer1/checkpointer/ema_fsdp_checkpointer.py new file mode 100644 index 00000000..4553ef5c --- /dev/null +++ b/cosmos_transfer1/checkpointer/ema_fsdp_checkpointer.py @@ -0,0 +1,54 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import warnings + +import attrs + +from cosmos_transfer1.utils import log +from cosmos_transfer1.utils.config import CheckpointConfig as BaseCheckpointConfig +from cosmos_transfer1.utils.config import make_freezable +from cosmos_transfer1.checkpointer.fsdp_checkpointer import FSDPCheckpointer as BaseFSDPCheckpointer + + +@make_freezable +@attrs.define(slots=False) +class CheckpointConfig(BaseCheckpointConfig): + load_ema_to_reg: bool = False + + +class FSDPCheckpointer(BaseFSDPCheckpointer): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + if not isinstance(self.config_checkpoint, CheckpointConfig): + warnings.warn( + "The 'config_checkpoint' is not an instance of 'CheckpointConfig'. " + "This behavior is deprecated and will not be supported in future versions. " + "Please update 'config_checkpoint' to be of type 'CheckpointConfig'.", + DeprecationWarning, + ) + + self.load_ema_to_reg = False + else: + self.load_ema_to_reg = self.config_checkpoint.load_ema_to_reg + + log.critical(f"load_ema_to_reg: {self.load_ema_to_reg}", rank0_only=False) + + def load_model_during_init(self, model, is_ema: bool = False, ema_id: int = 0): + if self.load_ema_to_reg and is_ema is False: + is_ema = True + ema_id = 0 + log.critical("Loading EMA model to regular model during initialization.", rank0_only=False) + super().load_model_during_init(model, is_ema, ema_id) diff --git a/cosmos_transfer1/checkpointer/fsdp_checkpointer.py b/cosmos_transfer1/checkpointer/fsdp_checkpointer.py new file mode 100644 index 00000000..21b2fbb3 --- /dev/null +++ b/cosmos_transfer1/checkpointer/fsdp_checkpointer.py @@ -0,0 +1,392 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import os +import threading + +import torch +from torch.distributed.fsdp import FullOptimStateDictConfig, FullStateDictConfig +from torch.distributed.fsdp import FullyShardedDataParallel as FSDP +from torch.distributed.fsdp import StateDictType + +from cosmos_transfer1.utils import callback, distributed, log, misc +from cosmos_transfer1.utils.config import CheckpointConfig, JobConfig +from cosmos_transfer1.checkpointer.fsdp_optim_fix import scatter_full_optim_state_dict +from cosmos_transfer1.utils.model import Model + + +class FSDPCheckpointer: + """The checkpointer class. Supports checkpoint saving/loading to local disk.""" + + def __init__(self, config_checkpoint: CheckpointConfig, config_job: JobConfig, callbacks: callback.CallBackGroup): + """Constructor of the checkpointer. + + Args: + config_checkpoint (CheckpointConfig): The config object for the checkpointer. + """ + # Set the callback functions. + self.callbacks = callbacks + self.checkpoint_dir_local = f"{config_job.path_local}/checkpoints" + self.strict_resume = config_checkpoint.strict_resume + self.load_path = config_checkpoint.load_path + self.load_training_state = config_checkpoint.load_training_state + self.save_thread = None + self.config_checkpoint = config_checkpoint + + def _load_ckpt_file_during_init(self): + latest_checkpoint_file = self._read_latest_checkpoint_file() + if latest_checkpoint_file is not None: + # 1. Resume training from latest_checkpoint.txt under the same name. + checkpoint_dir = self.checkpoint_dir_local + checkpoint_path = os.path.join(checkpoint_dir, latest_checkpoint_file) + resume = True + log.critical(f"[Checkpoint] Found latest checkpoint file: {latest_checkpoint_file}") + log.critical(f"[Checkpoint] Loading from local path: {checkpoint_path}") + log.critical("[Checkpoint] Will resume full training state (model, optimizer, scheduler)") + else: + if self.load_path: + # 2. Load the module weights specified by config_checkpoint.path. + checkpoint_path = self.load_path + resume = self.load_training_state + log.critical(f"[Checkpoint] Using specified checkpoint path: {checkpoint_path}") + if resume: + log.critical("[Checkpoint] Will load complete training state (model, optimizer, scheduler)") + else: + log.critical("[Checkpoint] Will load model weights only (no optimizer/scheduler state)") + else: + # 3. Randomly initialize the model parameters and train from scratch. + checkpoint_path = None + resume = False + log.critical("[Checkpoint] No checkpoint path specified") + log.critical("[Checkpoint] Starting fresh training with random initialization") + return checkpoint_path, resume + + @misc.timer("FSDP.load_model_during_init") + def load_model_during_init(self, model, is_ema=False, ema_id: int = 0): + if ema_id > 0: + assert is_ema, "ema_id should be used with is_ema=True" + checkpoint_path, _ = self._load_ckpt_file_during_init() + if checkpoint_path is not None: + tag = "reg" if not is_ema else "ema" + default_checkpoint_path = checkpoint_path.replace(".pt", f"_{tag}_model.pt") + if not os.path.exists(default_checkpoint_path): + default_checkpoint_path = checkpoint_path # starting from the release checkpoint + log.warning(f"is_ema={is_ema} model is not found. Loading from {default_checkpoint_path}") + if tag == "ema" and ema_id > 0: + _checkpoint_path = checkpoint_path.replace(".pt", f"_RANK{ema_id}.pt") + _checkpoint_path = _checkpoint_path.replace(".pt", f"_{tag}_model.pt") + if self._check_checkpoint_exists(_checkpoint_path, is_raise=False): + default_checkpoint_path = _checkpoint_path + else: + print( + f"{distributed.get_rank()}: Checkpoint not found: {_checkpoint_path} " + f"(fallback to {default_checkpoint_path})" + ) + checkpoint_path = default_checkpoint_path + self._check_checkpoint_exists(checkpoint_path) + + log.info(f"Loading checkpoint (local): {checkpoint_path}") + state_dict = torch.load(checkpoint_path, map_location=lambda storage, loc: storage, weights_only=False) + log.success(f"Complete loading checkpoint (local): {checkpoint_path}") + log.info("- Loading the model...") + if self.strict_resume: + log.info(model.load_state_dict(state_dict, strict=self.strict_resume)) + else: + log.critical("\t Using non-strict model") + from cosmos_transfer1.diffusion.inference.inference_utils import non_strict_load_model + + log.info(non_strict_load_model(model, state_dict)) + log.info("-finish model loading") + else: + log.info(f"is_ema={is_ema} model is not found and loaded.") + + @misc.timer("FSDP.load_optim_scheduler_during_init") + def load_optim_scheduler_during_init(self, fsdp_model, optimizer, scheduler): + checkpoint_path, resume = self._load_ckpt_file_during_init() + log.critical(f"Loading optimizer and scheduler: {checkpoint_path} (resume: {resume}") + if checkpoint_path is not None: + if resume: + checkpoint_path = checkpoint_path.replace(".pt", "_optim.pt") + self._check_checkpoint_exists(checkpoint_path) + if distributed.get_rank() == 0: + log.info(f"Loading checkpoint (local): {checkpoint_path}") + state_dict = torch.load( + checkpoint_path, map_location=lambda storage, loc: storage, weights_only=False + ) + log.success(f"Complete loading checkpoint (local): {checkpoint_path}") + log.info("- Loading the optimizer (FSDP scatter)...") + else: + state_dict = { + "optimizer": None, + "scheduler": None, + } + distributed.barrier() + sharded_optimizer_state_dict = scatter_full_optim_state_dict( # <---- FSDP + state_dict["optimizer"], + fsdp_model, + ) + log.info("- Loading the optimizer (FSDP load_state_dict)...") + log.info(optimizer.load_state_dict(sharded_optimizer_state_dict)) + log.critical("Skip loading the scheduler...") + return + log.info("- Loading the scheduler...") + scheduler.load_state_dict(state_dict["scheduler"]) + + @misc.timer("FSDP get_optim_scheduler_state") + def get_optim_scheduler_state(self, optim, fsdp_model, scheduler): + with FSDP.state_dict_type( + fsdp_model, + StateDictType.FULL_STATE_DICT, + FullStateDictConfig(offload_to_cpu=True, rank0_only=True), + FullOptimStateDictConfig(offload_to_cpu=True, rank0_only=True), + ): + optim_statedict = FSDP.full_optim_state_dict(fsdp_model, optim) + scheduler_statedict = scheduler.state_dict() + return { + "optimizer": optim_statedict, + "scheduler": scheduler_statedict, + } + + def save( + self, + model: Model, + optimizer: torch.optim.Optimizer, + scheduler: torch.optim.lr_scheduler.LRScheduler, + grad_scaler: torch.amp.GradScaler, + iteration: int, + ) -> None: + """Save network weights, optimizer parameters, scheduler parameters to a checkpoint. + + Args: + model (Model): The PyTorch model. + optimizer (torch.optim.Optimizer): The model optimizer. + scheduler (torch.optim.lr_scheduler.LRScheduler): The optimization scheduler. + grad_scaler (torch.amp.GradScaler): The gradient scaler (for mixed precision training). + iteration (int): Current iteration number. + """ + self.callbacks.on_save_checkpoint_start(model, iteration) + + model_state_dict = model.state_dict_model() + optim_scheduler_state_dict = self.get_optim_scheduler_state(optimizer, model.model, scheduler) + torch.cuda.empty_cache() + state_dict = dict( + iteration=iteration, + ) + self.callbacks.on_save_checkpoint(model, state_dict=state_dict) + + postfix, replicate_idx, shard_idx, total_ema_num = model.get_ckpt_postfix() + if replicate_idx == 0 and shard_idx == 0: + pass # save whole; it is rank0 + elif replicate_idx < total_ema_num and shard_idx == 0: + model_state_dict["model"] = None # only save ema + optim_scheduler_state_dict = None + state_dict = None + else: + return + + checkpoint_file = f"iter_{iteration:09}{postfix}.pt" + # Wait for previous saver thread to end. + if self.save_thread: + self.save_thread.join() + # Run the checkpoint saver in a separate thread. + self.save_thread = threading.Thread( + target=self._save_worker_local, + daemon=False, + args=(model_state_dict, optim_scheduler_state_dict, state_dict, checkpoint_file, distributed.get_rank()), + ) + self.save_thread.start() + + # Note: Checkpoints are saved on a separate thread and this callback is not accurate. + # Please check logs from on_save_checkpoint_success() for better accuracy + self.callbacks.on_save_checkpoint_end(model=None, iteration=iteration) + + @misc.timer("checkpoint saving (local)") + def _save_worker_local( + self, + model_state_dict: dict[str, torch.Tensor], + optim_scheduler_state_dict: dict[str, torch.Tensor], + state_dict: dict[str, torch.Tensor], + checkpoint_file: str, + rank: int = 0, + ) -> None: + """Worker to save checkpoint to local disk, spawned with a child thread (runs in parallel with the training). + + Args: + state_dict (dict[str, torch.Tensor]): The state dict of the model/optimizer/scheduler. + checkpoint_file (str): The file name of the model checkpoint. + rank (int): GPU device (default: 0). + """ + checkpoint_path = os.path.join(self.checkpoint_dir_local, checkpoint_file) + os.makedirs(self.checkpoint_dir_local, exist_ok=True) + try: + model_state_dict, ema_model_state_dict = model_state_dict["model"], model_state_dict["ema"] + if model_state_dict is not None: + torch.save(model_state_dict, checkpoint_path.replace(".pt", "_reg_model.pt")) + if ema_model_state_dict is not None: + torch.save(ema_model_state_dict, checkpoint_path.replace(".pt", "_ema_model.pt")) + if optim_scheduler_state_dict is not None: + torch.save(optim_scheduler_state_dict, checkpoint_path.replace(".pt", "_optim.pt")) + if state_dict is not None: + torch.save(state_dict, checkpoint_path) + if rank == 0: + self._write_latest_checkpoint_file(checkpoint_file) + log.success(f"Saved checkpoint (local): {checkpoint_path}") + iteration = int(checkpoint_file.replace("iter_", "").replace(".pt", "")) + self.callbacks.on_save_checkpoint_success(iteration=iteration) + except Exception as e: # noqa: BLE001 + log.exception(f"Checkpoint failed to save (local): {e}") + + @misc.timer("checkpoint loading") + def load( + self, + model: Model, + optimizer: torch.optim.Optimizer | None = None, + scheduler: torch.optim.lr_scheduler.LRScheduler | None = None, + grad_scaler: torch.amp.GradScaler | None = None, + ) -> int: + """Load network weights and optimizer states from a checkpoint in a single process. + + The priority of the checkpoint loading logic is: + 1. Attempt to resume training if possible by looking for latest_checkpoint.txt under the same name. + 2. If no latest checkpoint were found, it loads the model weights specified by config_checkpoint.path. + - This is typically used for inference mode. + - If config_checkpoint.load_optimizer_state is True, then also load the optimizer and scheduler states. + 3. If none of the above, randomly initialize the model parameters and train from scratch. + + Args: + model (FSDPDiffModle): The PyTorch model. + optimizer (torch.optim.Optimizer | None): The model optimizer (default: None). + scheduler (torch.optim.lr_scheduler.LRScheduler | None): The optimization scheduler (default: None). + grad_scaler (torch.amp.GradScaler | None): The gradient scaler (for mixed precision training). + + Returns: + iteration (int): the iteration number to start/resume from. + """ + self.callbacks.on_load_checkpoint_start(model) + + del optimizer, grad_scaler + checkpoint_path, resume = self._load_ckpt_file_during_init() + iteration = 0 + if checkpoint_path is not None: + self._check_checkpoint_exists(checkpoint_path) + log.info(f"Loading checkpoint (local): {checkpoint_path}") + state_dict = torch.load(checkpoint_path, map_location=lambda storage, loc: storage, weights_only=False) + log.success(f"Complete loading checkpoint (local): {checkpoint_path}") + self.callbacks.on_load_checkpoint(model, state_dict=state_dict) + if resume: + iteration = state_dict["iteration"] + log.success("Done with loading the checkpoint.") + else: + log.info("Training from scratch.") + torch.cuda.empty_cache() + + self.callbacks.on_load_checkpoint_end(model) + + if scheduler is not None: + scheduler.last_epoch = iteration + log.critical(f"resume scheduler from {iteration}", rank0_only=False) + + return iteration + + def _read_latest_checkpoint_file(self) -> str | None: + """Get the file name of the latest saved checkpoint. If it doesn't exist, return None. + + Returns: + checkpoint_file (str | None): file name of the latest saved checkpoint. + """ + checkpoint_file = None + latest_path = os.path.join(self.checkpoint_dir_local, "latest_checkpoint.txt") + if os.path.isfile(latest_path): + checkpoint_file = open(latest_path).read().strip() + if checkpoint_file is None: + log.warning(f"Latest ckpt file not found: {latest_path}") + else: + log.info(f"Found latest checkpoint: {checkpoint_file}") + return checkpoint_file + + def _write_latest_checkpoint_file(self, checkpoint_file: str) -> None: + """Track the file name of the latest saved checkpoint. + + Args: + checkpoint_file (str): file name of the latest saved checkpoint. + """ + content = f"{checkpoint_file}\n" + latest_path = os.path.join(self.checkpoint_dir_local, "latest_checkpoint.txt") + with open(latest_path, "w") as file: + file.write(content) + + def _check_checkpoint_exists(self, checkpoint_path: str, is_raise: bool = True) -> None: + """If the file checkpoint_path does not exist, raise an error. + + Args: + checkpoint_path (str): full path to the checkpoint. + """ + if not os.path.exists(checkpoint_path): + if is_raise: + raise FileNotFoundError(f"File not found (local): {checkpoint_path}") + return True + + def finalize(self) -> None: + """Finalize the checkpointer.""" + if self.save_thread: + self.save_thread.join() + + +class FSDPInferenceCheckpointer: + def __init__( + self, + ckpt_path: str, + strict_resume: bool = True, + ): + self.ckpt_path = ckpt_path + self.strict_resume = strict_resume + + @misc.timer("FSDPInferenceCheckpointer.load_model_during_init") + def load_model_during_init(self, model, is_ema=False, ema_id: int = 0): + del ema_id + if is_ema: + log.warning("EMA model is not supported in inference mode.") + return + assert os.path.exists(self.ckpt_path) + log.info(f"Loading from {self.ckpt_path}") + state_dict = torch.load(self.ckpt_path, map_location=lambda storage, loc: storage, weights_only=False) + if self.strict_resume: + log.info(model.load_state_dict(state_dict, strict=self.strict_resume)) + else: + log.critical("\t Using non-strict model") + from cosmos_transfer1.checkpointer.fsdp_checkpointer import non_strict_load_model + + log.info(non_strict_load_model(model, state_dict)) + log.info("-finish model loading") + + def load_optim_scheduler_during_init(self, *args, **kwargs): + """ + We do not do load in inference mode. The function is here to maintain the same interface to avoid errors. + """ + pass + + def save(self, *args, **kwargs): + """ + We do not save anything in inference mode. The function is here to maintain the same interface to avoid errors. + """ + pass + + def load(self, *args, **kwargs): + """ + We do not do load in inference mode. The function is here to maintain the same interface to avoid errors. + """ + return 0 diff --git a/cosmos_transfer1/checkpointer/fsdp_optim_fix.py b/cosmos_transfer1/checkpointer/fsdp_optim_fix.py new file mode 100644 index 00000000..a08aa943 --- /dev/null +++ b/cosmos_transfer1/checkpointer/fsdp_optim_fix.py @@ -0,0 +1,351 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# flake8: noqa +# isort: skip_file + +""" +torch 2.2 has bugs in loading optimizer states for FSDP in hybrid mode +torch impl uses state.rank and dist.rank() inconsistently +The file fix the bugs. Verified it works for hybrid mode and fullly sharded mode +Please use the `scatter_full_optim_state_dict` in the code to replace the corresponding function in torch 2.2 +""" + +import copy +import warnings +from typing import Any, Dict, Iterable, List, Optional, Union + +import torch +import torch.distributed as dist +from torch import nn +from torch.distributed.fsdp import FullyShardedDataParallel +from torch.distributed.fsdp._debug_utils import SimpleProfiler +from torch.distributed.fsdp._optim_utils import ( + _flatten_optim_state, + _FSDPState, + _get_fqn_to_fsdp_param_info, + _get_param_to_fqns, + _OptimStateKey, + _PosDimTensorInfo, + _shard_orig_param_state, + tree_map_only, +) +from torch.distributed.fsdp.fully_sharded_data_parallel import _rekey_sharded_optim_state_dict + + +def _broadcast_processed_state( + fsdp_state: _FSDPState, + optim_state: Dict[str, Any], + group: Optional[dist.ProcessGroup], +) -> Dict[str, Any]: + objects: List[Any] = [None] + if fsdp_state.rank == 0: + objects[0] = tree_map_only( + torch.Tensor, + lambda v: v.cpu() if v.dim() == 0 else _PosDimTensorInfo(v.shape, v.dtype), + optim_state, + ) + dist.broadcast_object_list(objects, src=0, group=group) + if dist.get_rank() == 0: + return optim_state + else: + return objects[0] + + +def _broadcast_state(fsdp_state: _FSDPState, state: Any, group: Optional[dist.ProcessGroup]) -> Any: + if dist.get_rank() == 0: + if not isinstance(state, torch.Tensor) or state.dim() == 0: + return state + tensor = state.to(fsdp_state.compute_device) + else: + if isinstance(state, torch.Tensor): + assert state.dim() == 0, ( + "For non-zero ranks, a tensor state should have zero dimension, " + "but got the state with shape {state.shape()}." + ) + return state + elif not isinstance(state, _PosDimTensorInfo): + return state + tensor = torch.zeros(state.shape, dtype=state.dtype, device=fsdp_state.compute_device) + dist.broadcast(tensor, src=0, group=group) + return tensor + + +def _flatten_optim_state_dict( + optim_state_dict: Dict[str, Any], + model: nn.Module, + use_orig_params: bool = False, + optim: Optional[torch.optim.Optimizer] = None, + rank0_only: bool = False, + group: Optional[dist.ProcessGroup] = None, +) -> Dict[str, Any]: + """ + Flattens the full optimizer state dict, still keying by unflattened parameter + names. + + If ``use_orig_params`` is True, each rank will have all FSDP-managed + parameters but some of these parameters may be empty due to the sharding. + For a regular optim.Optimizer, states for those empty parameters will + not be initialized. So, when aggregating the FQNs across ranks, no assert + will be raised on a rank even if it does not have all the states -- it is + valid and FSDP know how to aggregate them. However, FSDP has to ignore + handling those parameters that are not managed by FSDP and do not exist on + the local rank -- it is managed by other parallelism and FSDP does not + know ho to handle/aggregate them. + + Note that ``_flatten_tensor_optim_state`` does not need ``optim`` to + flatten/shard the state. However, NamedOptimizer and KeyedOptimizer require + all the states even if the corresponding parameters are empty. To this end, + ``optim`` will be used to to get the initial state of the empty parameters. + ``optim`` should only be non-None if the ``optim` is KeyedOptimizer or + NamedOptimizer. + + Returns: + Dict[str, Any]: The flattened optimizer state dict. + """ + SimpleProfiler.reset() + + unflat_osd = optim_state_dict + if "state" not in unflat_osd and not rank0_only: + raise ValueError('`optim_state_dict` must have the keys "state"' "to be a valid optimizer state dict") + param_to_fqns = _get_param_to_fqns(model) + fqn_to_fsdp_param_info = _get_fqn_to_fsdp_param_info(model) + fsdp_state = next(iter(fqn_to_fsdp_param_info.values())).state + + # Broadcast unflat_osd without non-scalar tensor if rank0_only is True. + if rank0_only: + unflat_osd = _broadcast_processed_state(fsdp_state, unflat_osd, group=group) + + # Construct the "state" part + flat_osd_state: Dict[Union[_OptimStateKey, str], Any] = {} + unflat_osd_state = unflat_osd["state"] + all_state_keys = set(unflat_osd_state.keys()) + + for param, fqns in param_to_fqns.items(): + fqn = fqns[0] + if fqn not in unflat_osd_state: + continue + all_state_keys.difference_update(fqns) + + if rank0_only: + for fqn in fqns: + if not unflat_osd_state[fqn]: + continue + for state_name in unflat_osd_state[fqn].keys(): + unflat_osd_state[fqn][state_name] = _broadcast_state( + fsdp_state, unflat_osd_state[fqn][state_name], group=group + ) + fqn = fqns[0] + if fqn in fqn_to_fsdp_param_info: + fsdp_param_info = fqn_to_fsdp_param_info[fqn] + if use_orig_params: + with SimpleProfiler.profile(SimpleProfiler.Type.RESHARDING): + flat_state = _shard_orig_param_state( + fsdp_param_info, + fqn, + unflat_osd_state[fqn], + ) + else: + flat_state = _flatten_optim_state( + fsdp_param_info, + unflat_osd_state, + fqns, + ) + key = _OptimStateKey(tuple(fqns), True) + # Only include non-empty states since as expected by + # `torch.optim.Optimizer` s unless the optimizer is KeyedOptimizer + # or NamedOptimizer. + if flat_state: + flat_osd_state[key] = flat_state + elif use_orig_params: + assert len(fqns) == 1, f"use_orig_params is True but there are multiple FQNs, {fqns}." + if optim is not None: # NamedOptimizer or KeyedOptimizer case. + state = optim.state.get(param, None) # type: ignore[call-overload] + if state is not None: + flat_osd_state[key] = copy.deepcopy(state) + else: + warnings.warn(f"optim_state[{key}] is not on rank{fsdp_state.rank}.") + + else: + raise RuntimeError(f"The state of {key} is empty. This should happen when " "use_orig_params=True.") + else: # do not flatten non-FSDP parameters' states + assert len(fqns) == 1 + key = _OptimStateKey(tuple(fqns), False) + flat_osd_state[key] = copy.copy(unflat_osd_state[fqn]) + + if rank0_only: + for fqn in fqns: + if not unflat_osd_state[fqn]: + continue + for state_name, param_state in list(unflat_osd_state[fqn].items()): + if fsdp_state.rank > 0: + # Deference the tensor so that PyTorch can collect the memory. + del unflat_osd_state[fqn][state_name] + else: + # Move the tensor in the original osd back to CPU to make the + # original osd unaffected. + unflat_osd_state[fqn][state_name] = unflat_osd_state[fqn][state_name].cpu() + + # Handle user-defined state, states that are not associated with parameters. + for key in all_state_keys: + user_state = unflat_osd_state[key] + if isinstance(user_state, torch.Tensor) and rank0_only and use_orig_params: + user_state = _broadcast_state(fsdp_state, user_state, group=group) + flat_osd_state[key] = copy.copy(user_state) + + SimpleProfiler.dump_and_reset("FSDP _flatten_optim_state_dict() profiling: ") + # Construct the "param_groups" part -- copy as is since it will be + # rekeyed later according to the target rank's optimizer + # Only copy param_groups if it exists in unflat_osd + if "param_groups" in unflat_osd: + flat_osd_param_groups = copy.deepcopy(unflat_osd["param_groups"]) + return {"state": flat_osd_state, "param_groups": flat_osd_param_groups} + else: + return {"state": flat_osd_state} + + +def _optim_state_dict_to_load_impl( + optim_state_dict: Dict[str, Any], + model: torch.nn.Module, + optim_input: Optional[ + Union[ + List[Dict[str, Any]], + Iterable[torch.nn.Parameter], + ] + ] = None, + optim: Optional[torch.optim.Optimizer] = None, + full_state_dict: bool = True, + rank0_only: bool = False, + is_named_optimizer: bool = False, + group: Optional[dist.ProcessGroup] = None, +) -> Dict[str, Any]: + """ + The internal API that is used by all the load optim_state_dict implementations. + Given model, optim, and the saved optim_state_dict, this API adds the FSDP + internal information and internal sharding to the optim_state_dict. + """ + if full_state_dict: + FullyShardedDataParallel._warn_optim_input(optim_input) + using_optim_input = FullyShardedDataParallel._is_using_optim_input( + optim_input, + optim, + ) + else: + using_optim_input = False + assert optim_input is None and not rank0_only + + use_orig_params = FullyShardedDataParallel.fsdp_modules(model)[0]._use_orig_params + assert all( + use_orig_params == m._use_orig_params for m in FullyShardedDataParallel.fsdp_modules(model) + ), "Not all FSDP modules have the same _use_orig_params value" + + if rank0_only and dist.get_rank(group) > 0: + optim_state_dict = {} + sharded_osd = _flatten_optim_state_dict( + optim_state_dict, + model=model, + use_orig_params=use_orig_params, + optim=(optim if is_named_optimizer else None), + rank0_only=rank0_only, + group=group, + ) + return _rekey_sharded_optim_state_dict( + sharded_osd, + model=model, + optim=optim, + optim_input=optim_input, + using_optim_input=using_optim_input, + is_named_optimizer=is_named_optimizer, + ) + + +def scatter_full_optim_state_dict( + full_optim_state_dict: Optional[Dict[str, Any]], + model: torch.nn.Module, + optim_input: Optional[ + Union[ + List[Dict[str, Any]], + Iterable[torch.nn.Parameter], + ] + ] = None, + optim: Optional[torch.optim.Optimizer] = None, + group: Optional[Any] = None, +) -> Dict[str, Any]: + """ + Scatters the full optimizer state dict from rank 0 to all other ranks, + returning the sharded optimizer state dict on each rank. The return + value is the same as :meth:`shard_full_optim_state_dict`, and on rank + 0, the first argument should be the return value of + :meth:`full_optim_state_dict`. + + Example:: + + >>> # xdoctest: +SKIP("undefined variables") + >>> from torch.distributed.fsdp import FullyShardedDataParallel as FSDP + >>> model, optim = ... + >>> full_osd = FSDP.full_optim_state_dict(model, optim) # only non-empty on rank 0 + >>> # Define new model with possibly different world size + >>> new_model, new_optim, new_group = ... + >>> sharded_osd = FSDP.scatter_full_optim_state_dict(full_osd, new_model, group=new_group) + >>> new_optim.load_state_dict(sharded_osd) + + .. note:: Both :meth:`shard_full_optim_state_dict` and + :meth:`scatter_full_optim_state_dict` may be used to get the + sharded optimizer state dict to load. Assuming that the full + optimizer state dict resides in CPU memory, the former requires + each rank to have the full dict in CPU memory, where each rank + individually shards the dict without any communication, while the + latter requires only rank 0 to have the full dict in CPU memory, + where rank 0 moves each shard to GPU memory (for NCCL) and + communicates it to ranks appropriately. Hence, the former has + higher aggregate CPU memory cost, while the latter has higher + communication cost. + + Args: + full_optim_state_dict (Optional[Dict[str, Any]]): Optimizer state + dict corresponding to the unflattened parameters and holding + the full non-sharded optimizer state if on rank 0; the argument + is ignored on nonzero ranks. + model (torch.nn.Module): Root module (which may or may not be a + :class:`FullyShardedDataParallel` instance) whose parameters + correspond to the optimizer state in ``full_optim_state_dict``. + optim_input (Optional[Union[List[Dict[str, Any]], Iterable[torch.nn.Parameter]]]): + Input passed into the optimizer representing either a + :class:`list` of parameter groups or an iterable of parameters; + if ``None``, then this method assumes the input was + ``model.parameters()``. This argument is deprecated, and there + is no need to pass it in anymore. (Default: ``None``) + optim (Optional[torch.optim.Optimizer]): Optimizer that will load + the state dict returned by this method. This is the preferred + argument to use over ``optim_input``. (Default: ``None``) + group (dist.ProcessGroup): Model's process group or ``None`` if + using the default process group. (Default: ``None``) + + Returns: + Dict[str, Any]: The full optimizer state dict now remapped to + flattened parameters instead of unflattened parameters and + restricted to only include this rank's part of the optimizer state. + """ + FullyShardedDataParallel._warn_legacy_optim_state_dict("scatter_full_optim_state_dict", "optim_state_dict_to_load") + return _optim_state_dict_to_load_impl( + optim_state_dict=full_optim_state_dict, + model=model, + optim_input=optim_input, + optim=optim, + full_state_dict=True, + rank0_only=True, + is_named_optimizer=False, + group=group, + ) diff --git a/cosmos_transfer1/checkpointer/multi_rank_checkpointer.py b/cosmos_transfer1/checkpointer/multi_rank_checkpointer.py new file mode 100644 index 00000000..d6408c91 --- /dev/null +++ b/cosmos_transfer1/checkpointer/multi_rank_checkpointer.py @@ -0,0 +1,236 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import os +import threading +from typing import List, NamedTuple, Tuple + +import torch + +from cosmos_transfer1.utils import distributed, log, misc +from cosmos_transfer1.utils.checkpointer import Checkpointer as BaseCheckpointer +from cosmos_transfer1.utils.model import Model + +TORCH_VERSION: Tuple[int, ...] = tuple(int(x) for x in torch.__version__.split(".")[:2]) +if TORCH_VERSION >= (1, 11): + from torch.ao import quantization + from torch.ao.quantization import FakeQuantizeBase, ObserverBase +elif ( + TORCH_VERSION >= (1, 8) + and hasattr(torch.quantization, "FakeQuantizeBase") + and hasattr(torch.quantization, "ObserverBase") +): + from torch import quantization + from torch.quantization import FakeQuantizeBase, ObserverBase + + +class _IncompatibleKeys( + NamedTuple( + "IncompatibleKeys", + [ + ("missing_keys", List[str]), + ("unexpected_keys", List[str]), + ("incorrect_shapes", List[Tuple[str, Tuple[int], Tuple[int]]]), + ], + ) +): + pass + + +class MultiRankCheckpointer(BaseCheckpointer): + def save( + self, + model: Model, + optimizer: torch.optim.Optimizer, + scheduler: torch.optim.lr_scheduler.LRScheduler, + grad_scaler: torch.amp.GradScaler, + iteration: int, + ) -> None: + """Save network weights, optimizer parameters, scheduler parameters to a checkpoint. + + Args: + model (Model): The PyTorch model. + optimizer (torch.optim.Optimizer): The model optimizer. + scheduler (torch.optim.lr_scheduler.LRScheduler): The optimization scheduler. + grad_scaler (torch.amp.GradScaler): The gradient scaler (for mixed precision training). + iteration (int): Current iteration number. + """ + # checkpoint_file = f"iter_{iteration:09}.pt" + postfix, _, total_ema_num = model.get_ckpt_postfix() + checkpoint_file = f"iter_{iteration:09}{postfix}.pt" + save_ranks = list(range(total_ema_num)) + for _rank in save_ranks: + if distributed.get_rank() == _rank: + state_dict = dict( + model=model.state_dict(), + optimizer=optimizer.state_dict(), + scheduler=scheduler.state_dict(), + grad_scaler=grad_scaler.state_dict(), + iteration=iteration, + ) + state_dict = misc.to(state_dict, device="cpu") + self.callbacks.on_save_checkpoint(model, state_dict=state_dict) + # Wait for previous saver thread to end. + if self.save_thread: + self.save_thread.join() + # Run the checkpoint saver in a separate thread. + self.save_thread = threading.Thread( + target=self._save_worker_local, + daemon=False, + args=(state_dict, checkpoint_file, distributed.get_rank()), + ) + self.save_thread.start() + + @misc.timer("checkpoint loading") + def load( + self, + model: Model, + optimizer: torch.optim.Optimizer | None = None, + scheduler: torch.optim.lr_scheduler.LRScheduler | None = None, + grad_scaler: torch.amp.GradScaler | None = None, + ) -> int: + """Load network weights and optimizer states from a checkpoint in a single process. + + The priority of the checkpoint loading logic is: + 1. Attempt to resume training if possible by looking for latest_checkpoint.txt under the same name. + 2. If no latest checkpoint were found, it loads the model weights specified by config_checkpoint.path. + - This is typically used for inference mode. + - If config_checkpoint.load_optimizer_state is True, then also load the optimizer and scheduler states. + 3. If none of the above, randomly initialize the model parameters and train from scratch. + + Args: + model (Model): The PyTorch model. + optimizer (torch.optim.Optimizer | None): The model optimizer (default: None). + scheduler (torch.optim.lr_scheduler.LRScheduler | None): The optimization scheduler (default: None). + grad_scaler (torch.amp.GradScaler | None): The gradient scaler (for mixed precision training). + + Returns: + iteration (int): the iteration number to start/resume from. + """ + latest_checkpoint_file = self._read_latest_checkpoint_file() + if latest_checkpoint_file is not None: + # different from base checkpointer, this support multi-EMA + postfix, _, total_ema_num = model.get_ckpt_postfix() + latest_checkpoint_file = latest_checkpoint_file.replace(".pt", f"{postfix}.pt") + # 1. Resume training from latest_checkpoint.txt under the same name. + checkpoint_dir = self.checkpoint_dir_local + checkpoint_path = os.path.join(checkpoint_dir, latest_checkpoint_file) + resume = True + else: + if self.load_path: + # 2. Load the module weights specified by config_checkpoint.path. + checkpoint_path = self.load_path + # different from base checkpointer, this support multi-EMA + postfix, _, total_ema_num = model.get_ckpt_postfix() + checkpoint_path = checkpoint_path.replace(".pt", f"{postfix}.pt") + resume = self.load_training_state + else: + # 3. Randomly initialize the model parameters and train from scratch. + checkpoint_path = None + resume = False + # Load checkpoint. + if checkpoint_path is not None: + self._check_checkpoint_exists(checkpoint_path) + log.info(f"Loading checkpoint (local): {checkpoint_path}") + state_dict = torch.load(checkpoint_path, map_location=lambda storage, loc: storage) + log.success(f"Complete loading checkpoint (local): {checkpoint_path}") + self.callbacks.on_load_checkpoint(model, state_dict=state_dict) + # Load the state dicts. + log.info("- Loading the model...") + log.critical(model.load_state_dict(state_dict["model"], strict=self.strict_resume)) + if resume: + iteration = state_dict["iteration"] + assert optimizer and scheduler + log.info("- Loading the optimizer...") + optimizer.load_state_dict(state_dict["optimizer"]) + log.info("- Loading the scheduler...") + scheduler.load_state_dict(state_dict["scheduler"]) + scheduler.last_epoch = iteration + log.info("- Loading the gradient scaler...") + grad_scaler.load_state_dict(state_dict["grad_scaler"]) + log.success(f"Done with loading the checkpoint (iteration {iteration}).") + else: + iteration = 0 + log.success("Done with loading the checkpoint.") + else: + # Checkpoint not found and not specified. We will train everything from scratch. + iteration = 0 + log.info("Training from scratch.") + torch.cuda.empty_cache() + return iteration + + +# https://github.com/facebookresearch/fvcore/blob/9d683aae73fb899dd35d6cf6720e5ef567761c57/fvcore/common/checkpoint.py +def non_strict_load_model(model: torch.nn.Module, checkpoint_state_dict: dict) -> _IncompatibleKeys: + # workaround https://github.com/pytorch/pytorch/issues/24139 + model_state_dict = model.state_dict() + incorrect_shapes = [] + for k in list(checkpoint_state_dict.keys()): + if k in model_state_dict: + if "_extra_state" in k: # Key introduced by TransformerEngine for FP8 + log.warning(f"Skipping key {k} introduced by TransformerEngine for FP8 in the checkpoint.") + continue + model_param = model_state_dict[k] + # Allow mismatch for uninitialized parameters + if TORCH_VERSION >= (1, 8) and isinstance(model_param, torch.nn.parameter.UninitializedParameter): + continue + if not isinstance(model_param, torch.Tensor): + raise ValueError( + f"Find non-tensor parameter {k} in the model. type: {type(model_param)} {type(checkpoint_state_dict[k])}, please check if this key is safe to skip or not." + ) + + shape_model = tuple(model_param.shape) + shape_checkpoint = tuple(checkpoint_state_dict[k].shape) + if shape_model != shape_checkpoint: + has_observer_base_classes = ( + TORCH_VERSION >= (1, 8) + and hasattr(quantization, "ObserverBase") + and hasattr(quantization, "FakeQuantizeBase") + ) + if has_observer_base_classes: + # Handle the special case of quantization per channel observers, + # where buffer shape mismatches are expected. + def _get_module_for_key(model: torch.nn.Module, key: str) -> torch.nn.Module: + # foo.bar.param_or_buffer_name -> [foo, bar] + key_parts = key.split(".")[:-1] + cur_module = model + for key_part in key_parts: + cur_module = getattr(cur_module, key_part) + return cur_module + + cls_to_skip = ( + ObserverBase, + FakeQuantizeBase, + ) + target_module = _get_module_for_key(model, k) + if isinstance(target_module, cls_to_skip): + # Do not remove modules with expected shape mismatches + # them from the state_dict loading. They have special logic + # in _load_from_state_dict to handle the mismatches. + continue + + incorrect_shapes.append((k, shape_checkpoint, shape_model)) + checkpoint_state_dict.pop(k) + incompatible = model.load_state_dict(checkpoint_state_dict, strict=False) + # Remove keys with "_extra_state" suffix, which are non-parameter items introduced by TransformerEngine for FP8 handling + missing_keys = [k for k in incompatible.missing_keys if "_extra_state" not in k] + unexpected_keys = [k for k in incompatible.unexpected_keys if "_extra_state" not in k] + return _IncompatibleKeys( + missing_keys=missing_keys, + unexpected_keys=unexpected_keys, + incorrect_shapes=incorrect_shapes, + ) diff --git a/cosmos_transfer1/checkpointer/safe_broadcast.py b/cosmos_transfer1/checkpointer/safe_broadcast.py new file mode 100644 index 00000000..f914299c --- /dev/null +++ b/cosmos_transfer1/checkpointer/safe_broadcast.py @@ -0,0 +1,95 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import collections +import io +import pickle +from typing import Any + +import torch +import torch.distributed as dist + + +# https://github.com/pytorch/pytorch/blob/main/torch/distributed/optim/zero_redundancy_optimizer.py#L29 +def broadcast_object( + obj: Any, + src_rank: int, + group: object = dist.group.WORLD, + device: torch.device = torch.device("cpu"), +) -> Any: + r""" + Broadcasts an object to the given group. + + It will be sending the object if called from the source rank and receiving + the object otherwise. + + Arguments: + obj: object to broadcast; only used if called on the source rank. + src_rank (int): source rank. + group (``ProcessGroup``, optional): group used for the broadcast + (default: ``dist.group.WORLD``). + device (``torch.device``, optional): device to send from or receive + to (default: ``torch.device("cpu")``). + + Returns: + The broadcasted object. + """ + if dist.get_rank() == src_rank: + # Send the object + buffer = io.BytesIO() + torch.save(obj, buffer, pickle_protocol=pickle.HIGHEST_PROTOCOL) + data = bytearray(buffer.getbuffer()) + length_tensor = torch.LongTensor([len(data)]).to(device) + data_send_tensor = torch.ByteTensor(data).to(device) + dist.broadcast(length_tensor, src=src_rank, group=group, async_op=False) + dist.broadcast(data_send_tensor, src=src_rank, group=group, async_op=False) + else: + # Receive the object + length_tensor = torch.LongTensor([0]).to(device) + dist.broadcast(length_tensor, src=src_rank, group=group, async_op=False) + data_recv_tensor = torch.empty([int(length_tensor.item())], dtype=torch.uint8, device=device) + dist.broadcast(data_recv_tensor, src=src_rank, group=group, async_op=False) + buffer = io.BytesIO(data_recv_tensor.cpu().numpy()) + obj = torch.load(buffer, map_location=device, weights_only=False) + return obj + + +def _recursive_copy_to_device( + value: Any, + non_blocking: bool, + device: torch.device, +) -> Any: + r""" + Recursively searches lists, tuples, dicts and copies tensors to device if possible. + + Non-tensor values are passed as-is in the result. + + .. note: These are all copies, so if there are two objects that reference + the same object, then after this call, there will be two different objects + referenced on the device. + """ + if isinstance(value, torch.Tensor): + return value.to(device, non_blocking=non_blocking) + + if isinstance(value, (list, tuple)): + values = [_recursive_copy_to_device(val, non_blocking=non_blocking, device=device) for val in value] + return values if isinstance(value, list) else tuple(values) + + if isinstance(value, collections.abc.Mapping): + return { + key: _recursive_copy_to_device(val, non_blocking=non_blocking, device=device) for key, val in value.items() + } + + return value diff --git a/cosmos_transfer1/checkpointer/tp_checkpointer.py b/cosmos_transfer1/checkpointer/tp_checkpointer.py new file mode 100644 index 00000000..0420857e --- /dev/null +++ b/cosmos_transfer1/checkpointer/tp_checkpointer.py @@ -0,0 +1,42 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from cosmos_transfer1.checkpointer.ddp_checkpointer import Checkpointer as DDPCheckpointer +from cosmos_transfer1.utils.model import Model + + +class Checkpointer(DDPCheckpointer): + """ + Checkpointer class for Tensor Parallelism (TP) in distributed training. + + This implementation supports the combination of Tensor Parallelism (TP) and Data Parallel Processing (DDP), with optional Context Parallelism (CP). + + Note: + - Fully Sharded Data Parallelism (FSDP) is not supported by this checkpointer. + - In principle, this implementation is also compatible with Pipeline Parallelism (PP) and Expert Parallelism (EP), which are other forms of model parallelism. However, PP and EP have not been tested yet. + """ + + def add_type_postfix_to_checkpoint_path(self, key: str, checkpoint_path: str, model: Model) -> str: + """ + Overwrite the `add_type_postfix_to_checkpoint_path` function of the base class (DDP checkpointer) + to append the TP-rank postfix to the checkpoint path. + """ + checkpoint_path = super().add_type_postfix_to_checkpoint_path(key, checkpoint_path, model) + if key == "trainer": + return checkpoint_path + else: + checkpoint_path = checkpoint_path.replace(".pt", f"_mp_{self.mp_rank}.pt") + + return checkpoint_path diff --git a/cosmos_transfer1/diffusion/conditioner.py b/cosmos_transfer1/diffusion/conditioner.py index 6962e103..00bfe588 100644 --- a/cosmos_transfer1/diffusion/conditioner.py +++ b/cosmos_transfer1/diffusion/conditioner.py @@ -132,6 +132,8 @@ class VideoExtendCondition(BaseVideoCondition): condition_video_input_mask: Optional[torch.Tensor] = None # condition_video_augment_sigma: (B, T) tensor of sigma value for the conditional input augmentation, only valid when apply_corruption_to_condition_region is "noise_with_sigma" or "noise_with_sigma_fixed" condition_video_augment_sigma: Optional[torch.Tensor] = None + # pose conditional input, will be concat with the input tensor + condition_video_pose: Optional[torch.Tensor] = None class GeneralConditioner(nn.Module, ABC): diff --git a/cosmos_transfer1/diffusion/config/base/conditioner.py b/cosmos_transfer1/diffusion/config/base/conditioner.py index 3007dd85..6a52df75 100644 --- a/cosmos_transfer1/diffusion/config/base/conditioner.py +++ b/cosmos_transfer1/diffusion/config/base/conditioner.py @@ -124,11 +124,42 @@ class VideoCondBoolConfig: obj: LazyDict = L(BooleanFlag)(output_key="video_cond_bool") dropout_rate: float = 0.2 input_key: str = "fps" # This is a placeholder, we never use this value + # Config below are for long video generation only + compute_loss_for_condition_region: bool = False # Compute loss for condition region + + # How to sample condition region during training. "first_random_n" set the first n frames to be condition region, n is random, "random" set the condition region to be random, + condition_location: str = "first_random_n" + random_conditon_rate: float = 0.5 # The rate to sample the condition region randomly + first_random_n_num_condition_t_max: int = 4 # The maximum number of frames to sample as condition region, used when condition_location is "first_random_n" + first_random_n_num_condition_t_min: int = 0 # The minimum number of frames to sample as condition region, used when condition_location is "first_random_n" + + # How to dropout value of the conditional input frames + cfg_unconditional_type: str = "zero_condition_region_condition_mask" # Unconditional type. "zero_condition_region_condition_mask" set the input to zero for condition region, "noise_x_condition_region" set the input to x_t, same as the base model + + # How to corrupt the condition region + apply_corruption_to_condition_region: str = "noise_with_sigma" # Apply corruption to condition region, option: "gaussian_blur", "noise_with_sigma", "clean" (inference), "noise_with_sigma_fixed" (inference) + # Inference only option: list of sigma value for the corruption at different chunk id, used when apply_corruption_to_condition_region is "noise_with_sigma" or "noise_with_sigma_fixed" + apply_corruption_to_condition_region_sigma_value: list[float] = [0.001, 0.2] + [ + 0.5 + ] * 10 # Sigma value for the corruption, used when apply_corruption_to_condition_region is "noise_with_sigma_fixed" + + # Add augment_sigma condition to the network + condition_on_augment_sigma: bool = False + # The following arguments is to match with previous implementation where we use train sde to sample augment sigma (with adjust video noise turn on) + augment_sigma_sample_p_mean: float = 0.0 # Mean of the augment sigma + augment_sigma_sample_p_std: float = 1.0 # Std of the augment sigma + augment_sigma_sample_multiplier: float = 4.0 # Multipler of augment sigma + + # Add pose condition to the network + add_pose_condition: bool = False # Sample PPP... from IPPP... sequence sample_tokens_start_from_p_or_i: bool = False + # Normalize the input condition latent + normalize_condition_latent: bool = False + @attrs.define(slots=False) class LatentConditionConfig: diff --git a/cosmos_transfer1/diffusion/config/base/model.py b/cosmos_transfer1/diffusion/config/base/model.py index d166fac8..3f93842c 100644 --- a/cosmos_transfer1/diffusion/config/base/model.py +++ b/cosmos_transfer1/diffusion/config/base/model.py @@ -17,7 +17,19 @@ import attrs +from cosmos_transfer1.diffusion.config.training.ema import PowerEMAConfig +from cosmos_transfer1.diffusion.training.modules.edm_sde import EDMSDE from cosmos_transfer1.utils.lazy_config import LazyDict +from cosmos_transfer1.utils.lazy_config import LazyCall as L + + +@attrs.define(slots=False) +class FSDPConfig: + policy: str = "block" + checkpoint: bool = False + min_num_params: int = 1024 + sharding_group_size: int = 8 + sharding_strategy: str = "full" @attrs.define(slots=False) @@ -30,6 +42,50 @@ class DefaultModelConfig: input_data_key: str = "video" # key to fetch input data from data_batch latent_shape: List[int] = [16, 24, 44, 80] # 24 corresponig to 136 frames + # training related + ema: LazyDict = PowerEMAConfig + sde: LazyDict = L(EDMSDE)( + p_mean=0.0, + p_std=1.0, + sigma_max=80, + sigma_min=0.0002, + ) + camera_sample_weight: LazyDict = LazyDict( + dict( + enabled=False, + weight=5.0, + ) + ) + aesthetic_finetuning: LazyDict = LazyDict( + dict( + enabled=False, + ) + ) + loss_mask_enabled: bool = False + loss_masking: LazyDict = None + loss_add_logvar: bool = True + input_image_key: str = "images_1024" # key to fetch input image from data_batch + loss_reduce: str = "sum" + loss_scale: float = 1.0 + fsdp_enabled: bool = False + use_torch_compile: bool = False + fsdp: FSDPConfig = attrs.field(factory=FSDPConfig) + use_dummy_temporal_dim: bool = False # Whether to use dummy temporal dimension in data + adjust_video_noise: bool = False # whether or not adjust video noise accroding to the video length + context_parallel_size: int = 1 # Number of context parallel groups + + # `num_latents_to_drop` is mechanism to satisfy the CP%8==0 and (1I,N*P,1I) latents setup. + # Since our tokenizer is causal and has the `T+1` input frames setup, it makes it + # a little challenging to sample exact number of frames from file, and encode those. + # Instead, we sample as many frame from file, run the tokenizer twice, and discard the second + # chunk's P-latents, ensuring the above two requirements. By default, this flag does not have any effect. + num_latents_to_drop: int = 0 # number of latents to drop + + +@attrs.define(slots=False) +class MultiviewModelConfig(DefaultModelConfig): + n_views: int = 6 + @attrs.define(slots=False) class LatentDiffusionDecoderModelConfig(DefaultModelConfig): diff --git a/cosmos_transfer1/diffusion/config/base/net.py b/cosmos_transfer1/diffusion/config/base/net.py index aef9a9f7..d7a6eab0 100644 --- a/cosmos_transfer1/diffusion/config/base/net.py +++ b/cosmos_transfer1/diffusion/config/base/net.py @@ -34,7 +34,9 @@ pos_emb_learnable=False, pos_emb_interpolation="crop", block_x_format="THWBD", + additional_timestamp_channels=None, affline_emb_norm=True, use_adaln_lora=True, adaln_lora_dim=256, -) + legacy_patch_emb=False, +) \ No newline at end of file diff --git a/cosmos_transfer1/diffusion/config/config.py b/cosmos_transfer1/diffusion/config/config.py index 58446923..d4281327 100644 --- a/cosmos_transfer1/diffusion/config/config.py +++ b/cosmos_transfer1/diffusion/config/config.py @@ -46,5 +46,8 @@ def make_config(): c.job.group = "inference" register_configs() + + # experiment config are defined in the experiment folder + # call import_all_modules_from_package to register them import_all_modules_from_package("cosmos_transfer1.diffusion.config.inference", reload=True) return c diff --git a/cosmos_transfer1/diffusion/config/config_train.py b/cosmos_transfer1/diffusion/config/config_train.py new file mode 100644 index 00000000..3d2defc8 --- /dev/null +++ b/cosmos_transfer1/diffusion/config/config_train.py @@ -0,0 +1,89 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any, List + +import attrs + +from cosmos_transfer1.diffusion.config.transfer.model import CtrlModelConfig +from cosmos_transfer1.checkpointer.ema_fsdp_checkpointer import CheckpointConfig +from cosmos_transfer1.diffusion.config.training.registry_extra import register_configs +from cosmos_transfer1.diffusion.model.model_ctrl import VideoDiffusionModelWithCtrl +from cosmos_transfer1.utils import config +from cosmos_transfer1.utils.config_helper import import_all_modules_from_package +from cosmos_transfer1.utils.lazy_config import PLACEHOLDER +from cosmos_transfer1.utils.lazy_config import LazyCall as L +from cosmos_transfer1.utils.lazy_config import LazyDict +from cosmos_transfer1.utils.trainer import Trainer + + +@attrs.define(slots=False) +class Config(config.Config): + # default config groups that will be used unless overwritten + # see config groups in registry.py + defaults: List[Any] = attrs.field( + factory=lambda: [ + "_self_", + {"data_train": None}, + {"data_val": None}, + {"optimizer": "fusedadamw"}, + {"scheduler": "lambdalinear"}, + {"callbacks": None}, + # + {"net": None}, + {"net_ctrl": None}, + {"hint_key": "control_input_edge"}, + {"conditioner": "ctrlnet_add_fps_image_size_padding_mask"}, + {"pixel_corruptor": None}, + {"fsdp": None}, + {"ema": "power"}, + {"checkpoint": "local"}, + {"ckpt_klass": "multi_rank"}, + {"tokenizer": "vae1"}, + # the list is with order, we need global experiment to be the last one + {"experiment": None}, + ] + ) + model_obj: LazyDict = L(VideoDiffusionModelWithCtrl)( + config=PLACEHOLDER, + ) + checkpoint: CheckpointConfig = attrs.field(factory=CheckpointConfig) + + + +def make_config(): + c = Config( + model=CtrlModelConfig(), + optimizer=None, + scheduler=None, + dataloader_train=None, + dataloader_val=None, + ) + + c.job.project = "cosmos_transfer1" + c.job.group = "debug" + c.job.name = "delete_${now:%Y-%m-%d}_${now:%H-%M-%S}" + + c.trainer.type = Trainer + # c.trainer.straggler_detection.enabled = False + c.trainer.max_iter = 400_000 + c.trainer.logging_iter = 10 + c.trainer.validation_iter = 100 + c.trainer.run_validation = False + c.trainer.callbacks = None + + register_configs() + import_all_modules_from_package("cosmos_transfer1.diffusion.config.training.experiment", reload=True) + return c diff --git a/cosmos_transfer1/diffusion/config/inference/cosmos-1-diffusion-control2world.py b/cosmos_transfer1/diffusion/config/inference/cosmos-1-diffusion-control2world.py index ab1caaa5..280d41c2 100644 --- a/cosmos_transfer1/diffusion/config/inference/cosmos-1-diffusion-control2world.py +++ b/cosmos_transfer1/diffusion/config/inference/cosmos-1-diffusion-control2world.py @@ -66,7 +66,7 @@ def make_ctrlnet_config_7b( job=dict( group="CTRL_7Bv1_lvg", name=f"CTRL_7Bv1pt3_lvg_tp_121frames_{hint_key}_block{num_control_blocks}", - project="cosmos_ctrlnet1", + project="cosmos_transfer1", ), model=dict( hint_mask=hint_mask, diff --git a/cosmos_transfer1/diffusion/config/registry.py b/cosmos_transfer1/diffusion/config/registry.py index eb2b9cdf..7b264c30 100644 --- a/cosmos_transfer1/diffusion/config/registry.py +++ b/cosmos_transfer1/diffusion/config/registry.py @@ -5,7 +5,7 @@ # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http:./www.apache.org/licenses/LICENSE # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, @@ -15,6 +15,7 @@ from hydra.core.config_store import ConfigStore + from cosmos_transfer1.diffusion.config.base.conditioner import ( BaseVideoConditionerConfig, VideoConditionerFpsSizePaddingConfig, @@ -63,6 +64,7 @@ def register_tokenizer(cs): ) + def register_configs(): cs = ConfigStore.instance() diff --git a/cosmos_transfer1/diffusion/config/training/__init__.py b/cosmos_transfer1/diffusion/config/training/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/cosmos_transfer1/diffusion/config/training/callbacks.py b/cosmos_transfer1/diffusion/config/training/callbacks.py new file mode 100644 index 00000000..6270c714 --- /dev/null +++ b/cosmos_transfer1/diffusion/config/training/callbacks.py @@ -0,0 +1,29 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from cosmos_transfer1.utils.lazy_config import LazyCall as L +from cosmos_transfer1.utils.lazy_config import PLACEHOLDER + +from cosmos_transfer1.diffusion.training.callbacks.iter_speed import IterSpeed +from cosmos_transfer1.diffusion.training.callbacks.low_precision import LowPrecisionCallback +from cosmos_transfer1.diffusion.training.callbacks.grad_clip import GradClip +from cosmos_transfer1.utils.callback import ProgressBarCallback + +BASIC_CALLBACKS = dict( + progress_bar=L(ProgressBarCallback)(), + grad_clip=L(GradClip)(fsdp_enabled=True, model_key="model"), + low_prec=L(LowPrecisionCallback)(config=PLACEHOLDER, trainer=PLACEHOLDER, update_iter=1), + iter_speed=L(IterSpeed)(every_n=200, hit_thres=1000), +) \ No newline at end of file diff --git a/cosmos_transfer1/diffusion/config/training/checkpoint.py b/cosmos_transfer1/diffusion/config/training/checkpoint.py new file mode 100644 index 00000000..7248fe53 --- /dev/null +++ b/cosmos_transfer1/diffusion/config/training/checkpoint.py @@ -0,0 +1,27 @@ + +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from typing import Dict +from cosmos_transfer1.utils.lazy_config import LazyCall as L +from cosmos_transfer1.checkpointer.fsdp_checkpointer import FSDPCheckpointer +from cosmos_transfer1.checkpointer.multi_rank_checkpointer import MultiRankCheckpointer +from cosmos_transfer1.checkpointer.tp_checkpointer import Checkpointer as TPCheckpointer + + +MULTI_RANK_CHECKPOINTER: Dict[str, str] = L(MultiRankCheckpointer)() +FSDP_CHECKPOINTER: Dict[str, str] = L(FSDPCheckpointer)() +MODEL_PARALLEL_CHECKPOINTER: Dict[str, str] = L(TPCheckpointer)() diff --git a/cosmos_transfer1/diffusion/config/training/ema.py b/cosmos_transfer1/diffusion/config/training/ema.py new file mode 100644 index 00000000..8faacbdd --- /dev/null +++ b/cosmos_transfer1/diffusion/config/training/ema.py @@ -0,0 +1,27 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from cosmos_transfer1.utils.ema import EMAModelTracker, PowerEMATracker +from cosmos_transfer1.utils.lazy_config import PLACEHOLDER +from cosmos_transfer1.utils.lazy_config import LazyCall as L +from cosmos_transfer1.utils.lazy_config import LazyDict + +PowerEMAConfig: LazyDict = L(PowerEMATracker.initialize_multi_rank_ema)( + model=PLACEHOLDER, enabled=True, rate=0.10, num=3 +) + +RegEMAConfig: LazyDict = L(EMAModelTracker.initialize_multi_rank_ema)( + model=PLACEHOLDER, enabled=True, rate=0.999, num=1 +) diff --git a/cosmos_transfer1/diffusion/config/training/experiment/ctrl_7b_tp_121frames.py b/cosmos_transfer1/diffusion/config/training/experiment/ctrl_7b_tp_121frames.py new file mode 100644 index 00000000..ec4f4854 --- /dev/null +++ b/cosmos_transfer1/diffusion/config/training/experiment/ctrl_7b_tp_121frames.py @@ -0,0 +1,211 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +This script will make + register the architecture + training-related configs for all the control modalities (one config per modality). +The configs are registered under the group "experiment" and can be used in training by passing the experiment name as an argument. + +Example usage: + - [dryrun, generate and inspect EdgeControl config] torchrun --nproc_per_node=1 -m cosmos_transfer1.diffusion.training.train --dryrun --config=cosmos_transfer1/diffusion/config/config_train.py -- experiment=CTRL_7Bv1pt3_lvg_tp_121frames_control_input_edge_block3 + - [real run, 8 gpu, train SegControl] torchrun --nproc_per_node=8 -m cosmos_transfer1.diffusion.training.train --config=cosmos_transfer1/diffusion/config/config_train.py -- experiment=CTRL_7Bv1pt3_lvg_tp_121frames_control_input_seg_block3 + - [real run, 8 gpu, train DepthControl] torchrun --nproc_per_node=8 -m cosmos_transfer1.diffusion.training.train --config=cosmos_transfer1/diffusion/config/config_train.py -- experiment=CTRL_7Bv1pt3_lvg_tp_121frames_control_input_depth_block3 +""" + +from hydra.core.config_store import ConfigStore + +from cosmos_transfer1.utils.lazy_config import LazyCall as L +from cosmos_transfer1.utils.lazy_config import LazyDict +from cosmos_transfer1.diffusion.config.transfer.blurs import random_blur_config +from cosmos_transfer1.diffusion.config.transfer.conditioner import CTRL_HINT_KEYS_COMB +from cosmos_transfer1.diffusion.model.model_ctrl import VideoDiffusionModelWithCtrl +from cosmos_transfer1.diffusion.networks.general_dit_video_conditioned import VideoExtendGeneralDIT + +cs = ConfigStore.instance() + +num_frames = 121 +num_blocks = 28 +num_control_blocks = 3 + +# TODO (qianlim) add data config +def get_data_train_name(hint_key: str) -> str: + pass + +def get_data_val_name(hint_key: str) -> str: + pass + +def make_ctrlnet_config_7b_training( + hint_key: str = "control_input_canny", + num_control_blocks: int = 3, +) -> LazyDict: + + data_train = get_data_train_name(hint_key) + data_val = get_data_val_name(hint_key) + + # Create the complete configuration in one step + config = LazyDict( + dict( + defaults=[ + {"override /net": "faditv2_7b"}, + {"override /net_ctrl": "faditv2_7b"}, + {"override /conditioner": "ctrlnet_add_fps_image_size_padding_mask"}, + {"override /tokenizer": "cosmos_diffusion_tokenizer_res720_comp8x8x8_t121_ver092624"}, + # + {"override /hint_key": hint_key}, + {"override /callbacks": "basic"}, + {"override /checkpoint": "local"}, + {"override /ckpt_klass": "multi_rank"}, + # + {"override /data_train": data_train}, + {"override /data_val": data_val}, + "_self_", + ], + job=dict( + group="CTRL_7Bv1_lvg", + name=f"CTRL_7Bv1pt3_lvg_tp_121frames_{hint_key}_block{num_control_blocks}", + project="cosmos_transfer1_posttrain", + ), + optimizer=dict( + lr=2 ** (-14.3), # ~5e-5 + weight_decay=0.1, + betas=[0.9, 0.99], + eps=1e-10, + ), + checkpoint=dict( + load_path="checkpoints/nvidia/Cosmos-Transfer1-7B/vis_control.pt", # modify as needed. Here we assume post-train our pre-trained VisControl model. + broadcast_via_filesystem=True, + save_iter=1000, + load_training_state=False, + strict_resume=True, + keys_not_to_resume=[], + ), + trainer=dict( + distributed_parallelism="ddp", + logging_iter=200, + max_iter=999_999_999, + ), + model_parallel=dict( + tensor_model_parallel_size=8, + sequence_parallel=True, + ), + model=dict( + fsdp_enabled=False, + context_parallel_size=1, + loss_reduce='mean', + latent_shape=[ + 16, + (num_frames - 1) // 8 + 1, + 88, + 160, + ], + base_load_from=dict( + load_path="checkpoints/nvidia/Cosmos-Transfer1-7B/base_model.pt", # modify as needed. This is the base model (that's frozen during training). + ), + finetune_base_model=False, + hint_mask=[True] * len(CTRL_HINT_KEYS_COMB[hint_key]), + hint_dropout_rate=0.3, + conditioner=dict( + video_cond_bool=dict( + condition_location="first_random_n", + cfg_unconditional_type="zero_condition_region_condition_mask", + apply_corruption_to_condition_region="noise_with_sigma", + condition_on_augment_sigma=False, + dropout_rate=0.0, + first_random_n_num_condition_t_max=2, + normalize_condition_latent=False, + augment_sigma_sample_p_mean=-3.0, + augment_sigma_sample_p_std=2.0, + augment_sigma_sample_multiplier=1.0, + ) + ), + net=L(VideoExtendGeneralDIT)( + extra_per_block_abs_pos_emb=True, + pos_emb_learnable=True, + extra_per_block_abs_pos_emb_type="learnable", + rope_h_extrapolation_ratio=1, + rope_t_extrapolation_ratio=2, + rope_w_extrapolation_ratio=1, + ), + adjust_video_noise=True, + net_ctrl=dict( + in_channels=17, + hint_channels=128, + num_blocks=num_blocks, + layer_mask=[True if (i >= num_control_blocks) else False for i in range(num_blocks)], + extra_per_block_abs_pos_emb=True, + pos_emb_learnable=True, + extra_per_block_abs_pos_emb_type="learnable", + ), + ema=dict( + enabled=True, + ), + ), + model_obj=L(VideoDiffusionModelWithCtrl)(), + scheduler=dict( + warm_up_steps=[2500], + cycle_lengths=[10000000000000], + f_start=[1.0e-6], + f_max=[1.0], + f_min=[1.0], + ), + dataloader_val=dict( + dataset=dict( + resolution="720", + num_video_frames=num_frames, + ), + ), + dataloader_train=dict( + dataloaders=dict( + image_data=dict( + dataloader=dict( + batch_size=1, + dataset=dict( + resolution="720", + blur_config=random_blur_config, + ), + ), + ratio=0, # only use video data for training. + ), + video_data=dict( + dataloader=dict( + batch_size=1, + dataset=dict( + resolution="720", + num_video_frames=num_frames, + blur_config=random_blur_config, + ), + ), + ratio=1, + ), + ), + ), + ) + ) + return config + + +""" +Register configurations +The loop below will register all experiments CTRL_7Bv1pt3_lvg_tp_121frames_control_input_{hint_key_name}_block3 for each hint_key_name +and then in training command, simply need to pass the "experiment" arg to override the configs. See the docstring at top of this script +for an example. +""" +for key in CTRL_HINT_KEYS_COMB.keys(): + config = make_ctrlnet_config_7b_training(hint_key=key, num_control_blocks=num_control_blocks) + cs.store( + group="experiment", + package="_global_", + name=config["job"]["name"], + node=config, + ) \ No newline at end of file diff --git a/cosmos_transfer1/diffusion/config/training/optim.py b/cosmos_transfer1/diffusion/config/training/optim.py new file mode 100644 index 00000000..55558950 --- /dev/null +++ b/cosmos_transfer1/diffusion/config/training/optim.py @@ -0,0 +1,40 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from cosmos_transfer1.diffusion.training.functional.lr_scheduler import LambdaLinearScheduler +from cosmos_transfer1.diffusion.training.utils.optim_instantiate import get_base_optimizer +from cosmos_transfer1.utils.lazy_config import PLACEHOLDER +from cosmos_transfer1.utils.lazy_config import LazyCall as L +from cosmos_transfer1.utils.lazy_config import LazyDict + +FusedAdamWConfig: LazyDict = L(get_base_optimizer)( + model=PLACEHOLDER, + lr=1e-4, + weight_decay=0.3, + betas=[0.9, 0.999], + optim_type="fusedadam", + eps=1e-8, + sharding=False, + master_weights=True, + capturable=True, +) + +LambdaLinearSchedulerConfig: LazyDict = L(LambdaLinearScheduler)( + warm_up_steps=[1000], + cycle_lengths=[10000000000000], + f_start=[1.0e-6], + f_max=[1.0], + f_min=[1.0], +) diff --git a/cosmos_transfer1/diffusion/config/training/registry.py b/cosmos_transfer1/diffusion/config/training/registry.py new file mode 100644 index 00000000..9d5be5c0 --- /dev/null +++ b/cosmos_transfer1/diffusion/config/training/registry.py @@ -0,0 +1,78 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +''' +Core training related registry. +''' + +from hydra.core.config_store import ConfigStore + + +from cosmos_transfer1.checkpointer.ema_fsdp_checkpointer import CheckpointConfig +from cosmos_transfer1.diffusion.config.training.ema import PowerEMAConfig +from cosmos_transfer1.diffusion.config.training.optim import FusedAdamWConfig, LambdaLinearSchedulerConfig +from cosmos_transfer1.diffusion.config.training.callbacks import BASIC_CALLBACKS +from cosmos_transfer1.diffusion.config.training.checkpoint import ( + FSDP_CHECKPOINTER, + MULTI_RANK_CHECKPOINTER, + MODEL_PARALLEL_CHECKPOINTER, +) + + +def register_ema(cs): + cs.store(group="ema", package="model.ema", name="power", node=PowerEMAConfig) + + +def register_optimizer(cs): + cs.store(group="optimizer", package="optimizer", name="fusedadamw", node=FusedAdamWConfig) + + +def register_scheduler(cs): + cs.store(group="scheduler", package="scheduler", name="lambdalinear", node=LambdaLinearSchedulerConfig) + +def register_callbacks(cs): + cs.store(group="callbacks", package="trainer.callbacks", name="basic", node=BASIC_CALLBACKS) + +def register_checkpoint_credential(cs): + CHECKPOINT_LOCAL = CheckpointConfig( + save_iter=1000, + load_path="", + load_training_state=False, + strict_resume=True, + ) + + cs.store(group="checkpoint", package="checkpoint", name="local", node=CHECKPOINT_LOCAL) + + +def register_checkpointer(cs): + cs.store(group="ckpt_klass", package="checkpoint.type", name="fsdp", node=FSDP_CHECKPOINTER) + cs.store(group="ckpt_klass", package="checkpoint.type", name="multi_rank", node=MULTI_RANK_CHECKPOINTER) + cs.store( + group="ckpt_klass", + package="checkpoint.type", + name="tp", + node=MODEL_PARALLEL_CHECKPOINTER, + ) + + +def register_configs(): + cs = ConfigStore.instance() + + register_optimizer(cs) + register_scheduler(cs) + register_ema(cs) + register_checkpoint_credential(cs) + register_checkpointer(cs) + register_callbacks(cs) \ No newline at end of file diff --git a/cosmos_transfer1/diffusion/config/training/registry_extra.py b/cosmos_transfer1/diffusion/config/training/registry_extra.py new file mode 100644 index 00000000..3769a299 --- /dev/null +++ b/cosmos_transfer1/diffusion/config/training/registry_extra.py @@ -0,0 +1,44 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +''' +Registry for training experiments, callbacks and data. +''' + +from hydra.core.config_store import ConfigStore + +import cosmos_transfer1.diffusion.config.registry as base_registry +import cosmos_transfer1.diffusion.config.training.registry as base_training_registry +from cosmos_transfer1.diffusion.config.transfer.conditioner import CTRL_HINT_KEYS + + +from cosmos_transfer1.diffusion.config.transfer.registry import register_experiment_ctrlnet + +# TODO (qianlim) add config / tutorial for mock data +def register_data_ctrlnet(cs): + pass + +def register_configs(): + cs = ConfigStore.instance() + + # This will register all the basic configs: net, conditioner, tokenizer. + base_registry.register_configs() + + # This will register training configs: optimizer, scheduler, callbacks, etc. + base_training_registry.register_configs() + + # following will register data, experiment, callbacks + # register_data_ctrlnet(cs) # Coming soon + register_experiment_ctrlnet(cs) diff --git a/cosmos_transfer1/diffusion/config/transfer/blurs.py b/cosmos_transfer1/diffusion/config/transfer/blurs.py index 49574d6a..b8f5b7f3 100644 --- a/cosmos_transfer1/diffusion/config/transfer/blurs.py +++ b/cosmos_transfer1/diffusion/config/transfer/blurs.py @@ -154,3 +154,27 @@ class BlurAugmentorConfig: # probabilities from the list of combinations should add up to 1.0 blur_combinations: List[BlurCombinationConfig] = [] downscale_factor: List[int] = [1] + + +# random blur for training the VisControl +random_blur_config = BlurAugmentorConfig( + downscale_factor=list(range(1, 5)), + blur_combinations=[ + BlurCombinationConfig( + blur_types=["bilateral"], + probability=0.5, + bilateral_filter=BilateralFilterConfig(use_random=True), + ), + BlurCombinationConfig( + blur_types=["gaussian"], + probability=0.3, + gaussian_blur=GaussianBlurConfig(use_random=True), + ), + BlurCombinationConfig( + blur_types=["bilateral", "gaussian"], + probability=0.2, + bilateral_filter=BilateralFilterConfig(use_random=True), + gaussian_blur=GaussianBlurConfig(use_random=True), + ), + ], +) diff --git a/cosmos_transfer1/diffusion/config/transfer/conditioner.py b/cosmos_transfer1/diffusion/config/transfer/conditioner.py index 3d725b90..5c130f0d 100644 --- a/cosmos_transfer1/diffusion/config/transfer/conditioner.py +++ b/cosmos_transfer1/diffusion/config/transfer/conditioner.py @@ -57,6 +57,14 @@ "control_input_lidar": [AddControlInputLIDAR], } +# SS=self-supervised +SS_CTRL_HINT_KEYS = [ + "control_input_canny", + "control_input_canny_blur", + "control_input_blur", + "control_input_upscale", +] + BaseVideoConditionerWithCtrlConfig: LazyDict = L(VideoConditionerWithCtrl)( text=TextConfig(), diff --git a/cosmos_transfer1/diffusion/config/transfer/registry.py b/cosmos_transfer1/diffusion/config/transfer/registry.py index 0fff37ba..e7b951ec 100644 --- a/cosmos_transfer1/diffusion/config/transfer/registry.py +++ b/cosmos_transfer1/diffusion/config/transfer/registry.py @@ -52,4 +52,4 @@ def register_experiment_ctrlnet(cs): def register_configs(): cs = ConfigStore.instance() base_registry.register_configs() - register_experiment_ctrlnet(cs) + register_experiment_ctrlnet(cs) \ No newline at end of file diff --git a/cosmos_transfer1/diffusion/functional/batch_ops.py b/cosmos_transfer1/diffusion/functional/batch_ops.py new file mode 100644 index 00000000..a72b2409 --- /dev/null +++ b/cosmos_transfer1/diffusion/functional/batch_ops.py @@ -0,0 +1,61 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Functions for performing operations with broadcasting to the right axis +# +# Example +# input1: tensor of size (N1, N2) +# input2: tensor of size (N1, N2, N3, N4) +# batch_mul(input1, input2) = input1[:, :, None, None] * input2 +# +# If the common dimensions don't match, we raise an assertion error. + +from torch import Tensor + + +def common_broadcast(x: Tensor, y: Tensor) -> tuple[Tensor, Tensor]: + ndims1 = x.ndim + ndims2 = y.ndim + + common_ndims = min(ndims1, ndims2) + for axis in range(common_ndims): + assert x.shape[axis] == y.shape[axis], "Dimensions not equal at axis {}".format(axis) + + if ndims1 < ndims2: + x = x.reshape(x.shape + (1,) * (ndims2 - ndims1)) + elif ndims2 < ndims1: + y = y.reshape(y.shape + (1,) * (ndims1 - ndims2)) + + return x, y + + +def batch_add(x: Tensor, y: Tensor) -> Tensor: + x, y = common_broadcast(x, y) + return x + y + + +def batch_mul(x: Tensor, y: Tensor) -> Tensor: + x, y = common_broadcast(x, y) + return x * y + + +def batch_sub(x: Tensor, y: Tensor) -> Tensor: + x, y = common_broadcast(x, y) + return x - y + + +def batch_div(x: Tensor, y: Tensor) -> Tensor: + x, y = common_broadcast(x, y) + return x / y diff --git a/cosmos_transfer1/diffusion/model/model_ctrl.py b/cosmos_transfer1/diffusion/model/model_ctrl.py index fd07b799..c8f840cc 100644 --- a/cosmos_transfer1/diffusion/model/model_ctrl.py +++ b/cosmos_transfer1/diffusion/model/model_ctrl.py @@ -13,14 +13,14 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Callable, Dict, Optional, Tuple, TypeVar, Union +from typing import Callable, Dict, Literal, Optional, Tuple, TypeVar, Union import torch from einops import rearrange from megatron.core import parallel_state from torch import Tensor -from cosmos_transfer1.diffusion.conditioner import VideoConditionerWithCtrl +from cosmos_transfer1.diffusion.conditioner import VideoConditionerWithCtrl, CosmosCondition from cosmos_transfer1.diffusion.inference.inference_utils import merge_patches_into_video, split_video_into_patches from cosmos_transfer1.diffusion.model.model_t2w import DiffusionT2WModel, broadcast_condition from cosmos_transfer1.diffusion.model.model_v2w import DiffusionV2WModel @@ -30,7 +30,7 @@ T = TypeVar("T") IS_PREPROCESSED_KEY = "is_preprocessed" - +COMMON_SOLVER_OPTIONS = Literal["2ab", "2mid", "1euler"] class VideoDiffusionModelWithCtrl(DiffusionV2WModel): def build_model(self) -> torch.nn.ModuleDict: @@ -163,6 +163,34 @@ def encode_latent(self, data_batch: dict, cond_mask: list = []) -> torch.Tensor: latent = torch.cat(latent, dim=1) return latent + def compute_loss_with_epsilon_and_sigma( + self, + data_batch: dict[str, Tensor], + x0_from_data_batch: Tensor, + x0: Tensor, + condition: CosmosCondition, + epsilon: Tensor, + sigma: Tensor, + ): + if self.is_image_batch(data_batch): + # Turn off CP + self.net.disable_context_parallel() + self.base_net.disable_context_parallel() + else: + if parallel_state.get_context_parallel_world_size() > 1: + # Turn on CP + cp_group = parallel_state.get_context_parallel_group() + self.net.enable_context_parallel(cp_group) + self.base_net.enable_context_parallel(cp_group) + log.debug("[CP] Split hint_input") + hint_key = self.config.hint_key["hint_key"] + x_hint_raw = getattr(condition, hint_key) + x_hint = split_inputs_cp(x=x_hint_raw, seq_dim=2, cp_group=self.net.cp_group) + setattr(condition, hint_key, x_hint) + return super().compute_loss_with_epsilon_and_sigma( + data_batch, x0_from_data_batch, x0, condition, epsilon, sigma + ) + def get_x0_fn_from_batch( self, data_batch: Dict, @@ -371,6 +399,258 @@ def generate_samples_from_batch( return samples + def get_patch_based_x0_fn( + self, + data_batch: Dict, + guidance: float = 1.5, + is_negative_prompt: bool = False, + condition_latent: torch.Tensor = None, + num_condition_t: Union[int, None] = None, + condition_video_augment_sigma_in_inference: float = None, + target_h: int = 2112, + target_w: int = 3840, + patch_h: int = 704, + patch_w: int = 1280, + seed_inference: int = 1, + ) -> Callable: + """ + Generates a callable function `x0_fn` based on the provided data batch and guidance factor. + The function will split the input into patches, run inference on each patch, then stitch them together. + + Additional args to original function: + target_h (int): final stitched video height + target_w (int): final stitched video width + patch_h (int): video patch height for each network inference + patch_w (int): video patch width for each network inference + + Returns: + - Callable: A function `x0_fn(noise_x, sigma)` that takes two arguments, `noise_x` and `sigma`, and return x0 prediction + """ + assert patch_h <= target_h and patch_w <= target_w + # data_batch should be the one processed by self.get_data_and_condition + if is_negative_prompt: + condition, uncondition = self.conditioner.get_condition_with_negative_prompt(data_batch) + else: + condition, uncondition = self.conditioner.get_condition_uncondition(data_batch) + if hasattr(self, "is_extend_model") and self.is_extend_model: + # Add conditions for long video generation. + if condition_latent is None: + condition_latent = torch.zeros(data_batch["latent_hint"].shape, **self.tensor_kwargs) + num_condition_t = 0 + condition_video_augment_sigma_in_inference = 1000 + + condition.video_cond_bool = True + condition = self.add_condition_video_indicator_and_video_input_mask( + condition_latent[:1], condition, num_condition_t + ) + uncondition.video_cond_bool = True # Not do cfg on condition frames + uncondition = self.add_condition_video_indicator_and_video_input_mask( + condition_latent[:1], uncondition, num_condition_t + ) + # Add extra conditions for ctrlnet. + latent_hint = data_batch["latent_hint"] + hint_key = data_batch["hint_key"] + setattr(condition, hint_key, latent_hint) + if "use_none_hint" in data_batch and data_batch["use_none_hint"]: + setattr(uncondition, hint_key, None) + else: + setattr(uncondition, hint_key, latent_hint) + + to_cp = self.net.is_context_parallel_enabled + # For inference, check if parallel_state is initialized + if parallel_state.is_initialized() and not self.is_image_batch(data_batch): + condition = broadcast_condition(condition, to_tp=True, to_cp=to_cp) + uncondition = broadcast_condition(uncondition, to_tp=True, to_cp=to_cp) + cp_group = parallel_state.get_context_parallel_group() + latent_hint = getattr(condition, hint_key) + latent_hint = split_inputs_cp(latent_hint, seq_dim=2, cp_group=cp_group) + + setattr(condition, "base_model", self.model.base_model) + setattr(uncondition, "base_model", self.model.base_model) + if hasattr(self, "hint_encoders"): + self.model.net.hint_encoders = self.hint_encoders + + def x0_fn(noise_x: torch.Tensor, sigma: torch.Tensor): + w, h = target_w, target_h + n_img_w = (w - 1) // patch_w + 1 + n_img_h = (h - 1) // patch_h + 1 + + overlap_size_w = overlap_size_h = 0 + if n_img_w > 1: + overlap_size_w = (n_img_w * patch_w - w) // (n_img_w - 1) + assert n_img_w * patch_w - overlap_size_w * (n_img_w - 1) == w + if n_img_h > 1: + overlap_size_h = (n_img_h * patch_h - h) // (n_img_h - 1) + assert n_img_h * patch_h - overlap_size_h * (n_img_h - 1) == h + + batch_images = noise_x + batch_sigma = sigma + output = [] + for idx, cur_images in enumerate(batch_images): + noise_x = cur_images.unsqueeze(0) + sigma = batch_sigma[idx : idx + 1] + condition.gt_latent = condition_latent[idx : idx + 1] + uncondition.gt_latent = condition_latent[idx : idx + 1] + setattr(condition, hint_key, latent_hint[idx : idx + 1]) + if getattr(uncondition, hint_key) is not None: + setattr(uncondition, hint_key, latent_hint[idx : idx + 1]) + + if self.is_image_batch(data_batch) or not issubclass(base_class, ExtendVideoDiffusionModel): + cond_x0 = self.denoise(noise_x, sigma, condition).x0 + uncond_x0 = self.denoise(noise_x, sigma, uncondition).x0 + else: + cond_x0 = self.denoise( + noise_x, + sigma, + condition, + condition_video_augment_sigma_in_inference=condition_video_augment_sigma_in_inference, + seed_inference=seed_inference, + ).x0_pred_replaced + uncond_x0 = self.denoise( + noise_x, + sigma, + uncondition, + condition_video_augment_sigma_in_inference=condition_video_augment_sigma_in_inference, + seed_inference=seed_inference, + ).x0_pred_replaced + x0 = cond_x0 + guidance * (cond_x0 - uncond_x0) + output.append(x0) + output = rearrange(torch.stack(output), "(n t) b ... -> (b n t) ...", n=n_img_h, t=n_img_w) # 8x3xhxw + final_output = merge_patches_into_video(output, overlap_size_h, overlap_size_w, n_img_h, n_img_w) + final_output = split_video_into_patches(final_output, patch_h, patch_w) + return final_output + + return x0_fn + + def generate_samples_from_patches( + self, + data_batch: Dict, + guidance: float = 1.5, + seed: int = 1, + state_shape: Tuple | None = None, + n_sample: int | None = None, + is_negative_prompt: bool = False, + num_steps: int = 35, + condition_latent: Union[torch.Tensor, None] = None, + num_condition_t: Union[int, None] = None, + condition_video_augment_sigma_in_inference: float = None, + solver_option: COMMON_SOLVER_OPTIONS = "2ab", + x_sigma_max: Optional[torch.Tensor] = None, + sigma_max: float | None = None, + target_h: int = 2112, + target_w: int = 3840, + patch_h: int = 704, + patch_w: int = 1280, + ) -> Tensor: + """ + Generate samples from the batch using patch-based inference. During each denoising step, it will denoise each patch + separately then average the overlapping regions. + + Additional args to original function: + target_h (int): final stitched video height + target_w (int): final stitched video width + patch_h (int): video patch height for each network inference + patch_w (int): video patch width for each network inference + """ + assert patch_h <= target_h and patch_w <= target_w + self._normalize_video_databatch_inplace(data_batch) + self._augment_image_dim_inplace(data_batch) + is_image_batch = self.is_image_batch(data_batch) + if is_image_batch: + log.debug("image batch, call base model generate_samples_from_batch") + return super().generate_samples_from_batch( + data_batch, + guidance=guidance, + seed=seed, + state_shape=state_shape, + n_sample=n_sample, + is_negative_prompt=is_negative_prompt, + num_steps=num_steps, + ) + if n_sample is None: + input_key = self.input_image_key if is_image_batch else self.input_data_key + n_sample = data_batch[input_key].shape[0] + if state_shape is None: + if is_image_batch: + state_shape = (self.state_shape[0], 1, *self.state_shape[2:]) # C,T,H,W + else: + log.debug(f"Default Video state shape is used. {self.state_shape}") + state_shape = self.state_shape + + x0_fn = self.get_patch_based_x0_fn( + data_batch, + guidance, + is_negative_prompt=is_negative_prompt, + condition_latent=condition_latent, + num_condition_t=num_condition_t, + condition_video_augment_sigma_in_inference=condition_video_augment_sigma_in_inference, + target_h=target_h, + target_w=target_w, + patch_h=patch_h, + patch_w=patch_w, + seed_inference=seed, + ) + + if sigma_max is None: + sigma_max = self.sde.sigma_max + + if x_sigma_max is None: + x_sigma_max = ( + misc.arch_invariant_rand( + (n_sample,) + tuple(state_shape), + torch.float32, + self.tensor_kwargs["device"], + seed, + ) + * sigma_max + ) + + if self.net.is_context_parallel_enabled: + x_sigma_max = broadcast(x_sigma_max, to_tp=True, to_cp=True) + x_sigma_max = split_inputs_cp(x=x_sigma_max, seq_dim=2, cp_group=self.net.cp_group) + + samples = self.sampler( + x0_fn, x_sigma_max, num_steps=num_steps, sigma_max=sigma_max, solver_option=solver_option + ) + if self.net.is_context_parallel_enabled: + samples = cat_outputs_cp(samples, seq_dim=2, cp_group=self.net.cp_group) + + return samples + + @torch.no_grad() + def validation_step( + self, data: dict[str, torch.Tensor], iteration: int + ) -> tuple[dict[str, torch.Tensor], torch.Tensor]: + """ + save generated videos + """ + raw_data, x0, condition = self.get_data_and_condition(data) + guidance = data["guidance"] + sigma_max = data["sigma_max"] + is_negative_prompt = data["is_negative_prompt"] + data = misc.to(data, **self.tensor_kwargs) + x_sigma_max = None + if sigma_max is not None: + x_sigma_max = self.get_x_from_clean(x0, sigma_max) + sample = self.generate_samples_from_batch( + data, + guidance=guidance, + # make sure no mismatch and also works for cp + state_shape=x0.shape[1:], + n_sample=x0.shape[0], + x_sigma_max=x_sigma_max, + sigma_max=sigma_max, + is_negative_prompt=is_negative_prompt, + ) + sample = self.decode(sample) + gt = raw_data + hint = data[data["hint_key"]][:, :3] + result = torch.cat([hint, sample], dim=3) + gt = torch.cat([hint, gt], dim=3) + caption = data["ai_caption"] + return {"gt": gt, "result": result, "caption": caption}, torch.tensor([0]).to(**self.tensor_kwargs) + + class VideoDiffusionT2VModelWithCtrl(DiffusionT2WModel): def build_model(self) -> torch.nn.ModuleDict: diff --git a/cosmos_transfer1/diffusion/model/model_v2w.py b/cosmos_transfer1/diffusion/model/model_v2w.py index 21e6642d..84662013 100644 --- a/cosmos_transfer1/diffusion/model/model_v2w.py +++ b/cosmos_transfer1/diffusion/model/model_v2w.py @@ -17,17 +17,17 @@ from typing import Callable, Dict, Optional, Tuple, Union import torch +from einops import rearrange from megatron.core import parallel_state from torch import Tensor from cosmos_transfer1.diffusion.conditioner import VideoExtendCondition from cosmos_transfer1.diffusion.config.base.conditioner import VideoCondBoolConfig from cosmos_transfer1.diffusion.diffusion.functional.batch_ops import batch_mul -from cosmos_transfer1.diffusion.model.model_t2w import DiffusionT2WModel -from cosmos_transfer1.diffusion.module.parallel import cat_outputs_cp, split_inputs_cp +from cosmos_transfer1.diffusion.model.model_t2w import DiffusionT2WModel, broadcast_condition +from cosmos_transfer1.diffusion.module.parallel import cat_outputs_cp, split_inputs_cp, broadcast from cosmos_transfer1.utils import log, misc - @dataclass class VideoDenoisePrediction: x0: torch.Tensor # clean data prediction @@ -328,15 +328,37 @@ def add_condition_video_indicator_and_video_input_mask( condition_video_indicator = torch.zeros(1, 1, T, 1, 1, device=latent_state.device).type( latent_dtype ) # 1 for condition region - - # Only in inference to decide the condition region - assert num_condition_t is not None, "num_condition_t should be provided" - assert num_condition_t <= T, f"num_condition_t should be less than T, get {num_condition_t}, {T}" - log.debug( - f"condition_location first_n, num_condition_t {num_condition_t}, condition.video_cond_bool {condition.video_cond_bool}" - ) - condition_video_indicator[:, :, :num_condition_t] += 1.0 - + if self.config.conditioner.video_cond_bool.condition_location == "first_n": + # Only in inference to decide the condition region + assert num_condition_t is not None, "num_condition_t should be provided" + assert num_condition_t <= T, f"num_condition_t should be less than T, get {num_condition_t}, {T}" + log.info( + f"condition_location first_n, num_condition_t {num_condition_t}, condition.video_cond_bool {condition.video_cond_bool}" + ) + condition_video_indicator[:, :, :num_condition_t] += 1.0 + elif self.config.conditioner.video_cond_bool.condition_location == "first_random_n": + # Only in training + num_condition_t_max = self.config.conditioner.video_cond_bool.first_random_n_num_condition_t_max + assert ( + num_condition_t_max <= T + ), f"num_condition_t_max should be less than T, get {num_condition_t_max}, {T}" + assert num_condition_t_max >= self.config.conditioner.video_cond_bool.first_random_n_num_condition_t_min + num_condition_t = torch.randint( + self.config.conditioner.video_cond_bool.first_random_n_num_condition_t_min, + num_condition_t_max + 1, + (1,), + ).item() + condition_video_indicator[:, :, :num_condition_t] += 1.0 + + elif self.config.conditioner.video_cond_bool.condition_location == "random": + # Only in training + condition_rate = self.config.conditioner.video_cond_bool.random_conditon_rate + flag = torch.ones(1, 1, T, 1, 1, device=latent_state.device).type(latent_dtype) * condition_rate + condition_video_indicator = torch.bernoulli(flag).type(latent_dtype).to(latent_state.device) + else: + raise NotImplementedError( + f"condition_location {self.config.conditioner.video_cond_bool.condition_location} not implemented; training={self.training}" + ) condition.gt_latent = latent_state condition.condition_video_indicator = condition_video_indicator @@ -355,4 +377,59 @@ def add_condition_video_indicator_and_video_input_mask( else: # Unconditional case, use for cfg condition.condition_video_input_mask = zeros_padding + to_cp = self.net.is_context_parallel_enabled + # For inference, check if parallel_state is initialized + if parallel_state.is_initialized(): + condition = broadcast_condition(condition, to_tp=True, to_cp=to_cp) + else: + assert not to_cp, "parallel_state is not initialized, context parallel should be turned off." + return condition + + + def add_condition_pose(self, data_batch: Dict, condition: VideoExtendCondition) -> VideoExtendCondition: + """Add pose condition to the condition object. For camera control model + Args: + data_batch (Dict): data batch, with key "plucker_embeddings", in shape B,T,C,H,W + latent_state (torch.Tensor): latent state tensor in shape B,C,T,H,W + condition (VideoExtendCondition): condition object + num_condition_t (int): number of condition latent T, used in inference to decide the condition region and config.conditioner.video_cond_bool.condition_location == "first_n" + Returns: + VideoExtendCondition: updated condition object + """ + assert ( + "plucker_embeddings" in data_batch or "plucker_embeddings_downsample" in data_batch.keys() + ), f"plucker_embeddings should be in data_batch. only find {data_batch.keys()}" + plucker_embeddings = ( + data_batch["plucker_embeddings"] + if "plucker_embeddings_downsample" not in data_batch.keys() + else data_batch["plucker_embeddings_downsample"] + ) + condition.condition_video_pose = rearrange(plucker_embeddings, "b t c h w -> b c t h w").contiguous() + to_cp = self.net.is_context_parallel_enabled + # For inference, check if parallel_state is initialized + if parallel_state.is_initialized(): + condition = broadcast_condition(condition, to_tp=True, to_cp=to_cp) + else: + assert not to_cp, "parallel_state is not initialized, context parallel should be turned off." + + return condition + + def sample_tokens_start_from_p_or_i(self, latent_state: torch.Tensor) -> torch.Tensor: + """Sample the PPP... from the IPPP... sequence, only for video sequence + Args: + latent_state (torch.Tensor): latent state tensor in shape B,C,T,H,W + Returns: + torch.Tensor: sampled PPP tensor in shape B,C,T,H,W + """ + B, C, T, H, W = latent_state.shape + latent_dtype = latent_state.dtype + T_target = self.state_shape[1] + latent_state_sample = torch.zeros((B, C, T_target, H, W), dtype=latent_dtype, device=latent_state.device) + t_start = torch.randint(0, T - T_target + 1, (1,)) + # broadcast to other device + latent_state_sample = latent_state[:, :, t_start : t_start + T_target].contiguous() + if parallel_state.is_initialized(): + latent_state_sample = broadcast(latent_state_sample, to_tp=True, to_cp=True) + + return latent_state_sample diff --git a/cosmos_transfer1/diffusion/module/blocks.py b/cosmos_transfer1/diffusion/module/blocks.py index 9ecbe337..d09f2470 100644 --- a/cosmos_transfer1/diffusion/module/blocks.py +++ b/cosmos_transfer1/diffusion/module/blocks.py @@ -30,6 +30,55 @@ def modulate(x, shift, scale): return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1) +class SDXLTimesteps(nn.Module): + def __init__(self, num_channels: int = 320): + super().__init__() + self.num_channels = num_channels + + def forward(self, timesteps): + in_dype = timesteps.dtype + half_dim = self.num_channels // 2 + exponent = -math.log(10000) * torch.arange(half_dim, dtype=torch.float32, device=timesteps.device) + exponent = exponent / (half_dim - 0.0) + + emb = torch.exp(exponent) + emb = timesteps[:, None].float() * emb[None, :] + + sin_emb = torch.sin(emb) + cos_emb = torch.cos(emb) + emb = torch.cat([cos_emb, sin_emb], dim=-1) + + return emb.to(in_dype) + + +class SDXLTimestepEmbedding(nn.Module): + def __init__(self, in_features: int, out_features: int, use_adaln_lora: bool = False): + super().__init__() + log.critical( + f"Using AdaLN LoRA Flag: {use_adaln_lora}. We enable bias if no AdaLN LoRA for backward compatibility." + ) + self.linear_1 = nn.Linear(in_features, out_features, bias=not use_adaln_lora) + self.activation = nn.SiLU() + self.use_adaln_lora = use_adaln_lora + if use_adaln_lora: + self.linear_2 = nn.Linear(out_features, 3 * out_features, bias=False) + else: + self.linear_2 = nn.Linear(out_features, out_features, bias=True) + + def forward(self, sample: torch.Tensor) -> torch.Tensor: + emb = self.linear_1(sample) + emb = self.activation(emb) + emb = self.linear_2(emb) + + if self.use_adaln_lora: + adaln_lora_B_3D = emb + emb_B_D = sample + else: + emb_B_D = emb + adaln_lora_B_3D = None + + return emb_B_D, adaln_lora_B_3D + class Timesteps(nn.Module): def __init__(self, num_channels): super().__init__() diff --git a/cosmos_transfer1/diffusion/module/parallel.py b/cosmos_transfer1/diffusion/module/parallel.py index e08356c9..31a1c402 100644 --- a/cosmos_transfer1/diffusion/module/parallel.py +++ b/cosmos_transfer1/diffusion/module/parallel.py @@ -160,4 +160,4 @@ def _robust_broadcast(tensor: torch.Tensor, src: int, pg, is_check_shape: bool = # Now broadcast the tensor data torch.distributed.broadcast(tensor, src, group=pg) - return tensor + return tensor \ No newline at end of file diff --git a/cosmos_transfer1/diffusion/module/position_embedding.py b/cosmos_transfer1/diffusion/module/position_embedding.py index c10e6e16..372f7777 100644 --- a/cosmos_transfer1/diffusion/module/position_embedding.py +++ b/cosmos_transfer1/diffusion/module/position_embedding.py @@ -13,9 +13,12 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional +from typing import Literal, Optional +import numpy as np import torch + +import torch.nn.functional as F from einops import rearrange, repeat from torch import nn from torch.distributed import ProcessGroup, get_process_group_ranks @@ -25,6 +28,62 @@ from cosmos_transfer1.diffusion.module.timm import trunc_normal_ +def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): + """ + embed_dim: output dimension for each position + pos: a list of positions to be encoded: size (M,) + out: (M, D) + """ + assert embed_dim % 2 == 0 + omega = np.arange(embed_dim // 2, dtype=np.float64) + omega /= embed_dim / 2.0 + omega = 1.0 / 10000**omega # (D/2,) + + pos = pos.reshape(-1) # (M,) + out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product + + emb_sin = np.sin(out) # (M, D/2) + emb_cos = np.cos(out) # (M, D/2) + + emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D) + return emb + + +def get_3d_sincos_pos_embed( + embed_dim, + grid_size_h, + grid_size_w, + grid_size_t, + spatial_interpolation_scale, + temporal_interpolation_scale, + concat=True, +): + grid_h = np.arange(grid_size_h, dtype=np.float32) / spatial_interpolation_scale + grid_w = np.arange(grid_size_w, dtype=np.float32) / spatial_interpolation_scale + grid_t = np.arange(grid_size_t, dtype=np.float32) / temporal_interpolation_scale + + grid = np.meshgrid(grid_w, grid_h, grid_t, indexing="ij") + grid = np.stack(grid, axis=0) + grid = grid.reshape(3, 1, grid_size_h, grid_size_w, grid_size_t) + + if concat: + per_axis = embed_dim // 3 + per_axis = (per_axis // 2) * 2 # make it even (for sin/cos split) + dim_h, dim_w = per_axis, per_axis + dim_t = embed_dim - dim_h - dim_w + emb_h = get_1d_sincos_pos_embed_from_grid(dim_h, grid[0]) # (H*W, D/3) + emb_w = get_1d_sincos_pos_embed_from_grid(dim_w, grid[1]) # (H*W, D/3) + emb_t = get_1d_sincos_pos_embed_from_grid(dim_t, grid[2]) # (H*W, D/3) + + return np.concatenate([emb_h, emb_w, emb_t], axis=1) # (H*W*T, D) + else: + emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim, grid[0]) # (H*W) + emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim, grid[1]) # (H*W) + emb_t = get_1d_sincos_pos_embed_from_grid(embed_dim, grid[2]) # (H*W) + + return emb_h + emb_w + emb_t # (H*W*T, D) + + class VideoPositionEmb(nn.Module): def __init__(self): super().__init__() @@ -60,6 +119,37 @@ def generate_embeddings(self, B_T_H_W_C: torch.Size, fps=Optional[torch.Tensor]) raise NotImplementedError + + +class VideoRopePositionEmb(VideoPositionEmb): + def __init__( + self, + *, # enforce keyword arguments + head_dim: int, + len_h: int, + len_w: int, + len_t: int, + **kwargs, # used for compatibility with other positional embeddings; unused in this class + ): + del kwargs + super().__init__() + self.register_buffer("seq", torch.arange(len_h * len_w * len_t, dtype=torch.float)) + + self.register_buffer( + "dim_range", torch.arange(0, head_dim, 2)[: (head_dim // 2)].float().cuda() / head_dim, persistent=False + ) + + def generate_embeddings(self, B_T_H_W_C: torch.Size, fps=Optional[torch.Tensor], ntk_factor: float = 1.0): + theta = 10000.0 * ntk_factor + + # original_dtype = self.dim_range.dtype + freq = 1.0 / (theta ** self.dim_range.float()) + _, T, H, W, _ = B_T_H_W_C + length = T * H * W + emb_L_D = torch.outer(self.seq[:length], freq) + return rearrange(torch.cat([emb_L_D, emb_L_D], dim=-1), "l d -> l 1 1 d").float() + + class VideoRopePosition3DEmb(VideoPositionEmb): def __init__( self, @@ -209,3 +299,346 @@ def generate_embeddings(self, B_T_H_W_C: torch.Size, fps=Optional[torch.Tensor]) raise ValueError(f"Unknown interpolation method {self.interpolation}") return normalize(emb, dim=-1, eps=1e-6) + + + +class LearnableEmb3D(VideoPositionEmb): + def __init__( + self, + *, # enforce keyword arguments + model_channels: int, + len_h: int, + len_w: int, + len_t: int, + interpolation: str = "crop", + is_learnable: bool = True, + **kwargs, # used for compatibility with other positional embeddings; unused in this class + ): + del kwargs # unused + super().__init__() + assert is_learnable is True + self.interpolation = interpolation + self.pos_embed = nn.Parameter(torch.zeros(1, len_t, len_h, len_w, model_channels)) + trunc_normal_(self.pos_embed, std=0.02) + + def generate_embeddings(self, B_T_H_W_C: torch.Size, fps=Optional[torch.Tensor]) -> torch.Tensor: + B, T, H, W, C = B_T_H_W_C + if self.interpolation == "crop": + return self.pos_embed[:, :T, :H, :W] + if self.interpolation == "resize": + return rearrange( + F.interpolate( + rearrange(self.pos_embed, "1 t h w c -> 1 c h w t"), + size=(H, W, T), + mode="linear", + align_corners=False, + ), + "1 c h w t -> 1 t h w c", + ) + raise ValueError(f"Unknown interpolation method {self.interpolation}") + + +class LearnableEmb3D_FPS_Aware(VideoPositionEmb): + def __init__( + self, + *, # enforce keyword arguments + model_channels: int, + len_h: int, + len_w: int, + len_t: int, + min_fps: int, # 1 for getty video + max_fps: int, # 120 for getty video + interpolation: str = "crop", + is_learnable: bool = True, + **kwargs, # used for compatibility with other positional embeddings; unused in this class + ): + del kwargs + super().__init__() + assert is_learnable is True + self.interpolation = interpolation + self.max_fps = max_fps + self.min_fps = min_fps + + if self.interpolation == "crop": + self.pos_embed = nn.Parameter( + torch.zeros(1, len_t * int(max_fps / min_fps), len_h, len_w, model_channels) + ) # should be max_seq_length * (max_fps / min_fps) + elif self.interpolation == "resize": + self.pos_embed = nn.Parameter( + torch.zeros(1, len_t, len_h, len_w, model_channels) + ) # time embedding based min fps + else: + ValueError(f"Unknown interpolation method {self.interpolation}") + + trunc_normal_(self.pos_embed, std=0.02) + + def generate_embeddings(self, B_T_H_W_C: torch.Size, fps=Optional[torch.Tensor]) -> torch.Tensor: + B, T, H, W, C = B_T_H_W_C + + if self.interpolation == "crop": + if T > 1: + return torch.cat( + [ + self.pos_embed[:, : (int(self.max_fps / curr_fps) * T) : int(self.max_fps / curr_fps), :H, :W] + for curr_fps in fps + ], + 0, + ) + else: + return self.pos_embed[:, :T, :H, :W] # image model + elif self.interpolation == "resize": + if T > 1: + return torch.cat( + [ + rearrange( + F.interpolate( + rearrange(self.pos_embed, "1 t h w c -> 1 c h w t"), + size=(H, W, T * int(curr_fps / self.min_fps)), + mode="trilinear", + align_corners=True, # important: align corner need to be true + )[:, :, :H, :W, :T], + "1 c h w t -> 1 t h w c", + ) + for curr_fps in fps + ], + 0, + ) + else: + # grab self.pos_embed at time step 0 and resize spatially + return rearrange( + F.interpolate( + rearrange(self.pos_embed[:, 0, ::], "1 h w c -> 1 c h w"), + size=(H, W), + mode="bilinear", + align_corners=True, + ), + "1 c h w -> 1 h w c", + ) + raise ValueError(f"Unknown interpolation method {self.interpolation}") + + + +class SinCosPosEmbAxis(VideoPositionEmb): + def __init__( + self, + *, # enforce keyword arguments + interpolation: str, + model_channels: int, + len_h: int, + len_w: int, + len_t: int, + h_extrapolation_ratio: float = 1.0, + w_extrapolation_ratio: float = 1.0, + t_extrapolation_ratio: float = 1.0, + **kwargs, + ): + """ + Args: + interpolation (str): we curretly only support "crop", ideally when we need extrapolation capacity, we should adjust frequency or other more advanced methods. they are not implemented yet. + """ + del kwargs # unused + super().__init__() + self.interpolation = interpolation + assert self.interpolation in ["crop"], f"Unknown interpolation method {self.interpolation}" + + dim = model_channels + dim_h = dim // 6 * 2 + dim_w = dim_h + dim_t = dim - 2 * dim_h + assert dim == dim_h + dim_w + dim_t, f"bad dim: {dim} != {dim_h} + {dim_w} + {dim_t}" + + # rescale pos id is equivalent to rescale frequency + emb_h = get_1d_sincos_pos_embed_from_grid(dim_h, pos=np.arange(len_h) * 1.0 / h_extrapolation_ratio) + emb_w = get_1d_sincos_pos_embed_from_grid(dim_w, pos=np.arange(len_w) * 1.0 / w_extrapolation_ratio) + emb_t = get_1d_sincos_pos_embed_from_grid(dim_t, pos=np.arange(len_t) * 1.0 / t_extrapolation_ratio) + + self.register_buffer("pos_emb_h", torch.from_numpy(emb_h).float(), persistent=False) + self.register_buffer("pos_emb_w", torch.from_numpy(emb_w).float(), persistent=False) + self.register_buffer("pos_emb_t", torch.from_numpy(emb_t).float(), persistent=False) + + def generate_embeddings(self, B_T_H_W_C: torch.Size, fps=Optional[torch.Tensor]) -> torch.Tensor: + B, T, H, W, C = B_T_H_W_C + if self.interpolation == "crop": + emb_h_H = self.pos_emb_h[:H] + emb_w_W = self.pos_emb_w[:W] + emb_t_T = self.pos_emb_t[:T] + emb = torch.cat( + [ + repeat(emb_t_T, "t d-> b t h w d", b=B, h=H, w=W), + repeat(emb_h_H, "h d-> b t h w d", b=B, t=T, w=W), + repeat(emb_w_W, "w d-> b t h w d", b=B, t=T, h=H), + ], + dim=-1, + ) + assert list(emb.shape)[:4] == [B, T, H, W], f"bad shape: {list(emb.shape)[:4]} != {B, T, H, W}" + return emb + + raise ValueError(f"Unknown interpolation method {self.interpolation}") + + + +class SinCosPosEmb_FPS_Aware(VideoPositionEmb): + def __init__( + self, + *, # enforce keyword arguments + model_channels: int, + len_h: int, + len_w: int, + len_t: int, + min_fps: int, # 1 for getty video + max_fps: int, # 120 for getty video + is_learnable: bool = False, + interpolation: str = "crop", + spatial_interpolation_scale=1.0, + temporal_interpolation_scale=1.0, + **kwargs, # used for compatibility with other positional embeddings; unused in this class + ): + del kwargs # unused + super().__init__() + self.interpolation = interpolation + self.max_fps = max_fps + self.min_fps = min_fps + if self.interpolation == "crop": + param = get_3d_sincos_pos_embed( + model_channels, + len_h, + len_w, + len_t * int(max_fps / min_fps), + spatial_interpolation_scale, + temporal_interpolation_scale, + ) # should be max_seq_length * (max_fps / min_fps) + elif self.interpolation == "resize": + param = get_3d_sincos_pos_embed( + model_channels, len_h, len_w, len_t, spatial_interpolation_scale, temporal_interpolation_scale + ) # time embedding based min fps + else: + ValueError(f"Unknown interpolation method {self.interpolation}") + param = rearrange(param, "(h w t) c -> 1 t h w c", h=len_h, w=len_w) + if is_learnable: + self.pos_embed = nn.Parameter( + torch.from_numpy(param).float(), + ) + else: + self.register_buffer("pos_embed", torch.from_numpy(param).float(), persistent=False) + + def generate_embeddings(self, B_T_H_W_C: torch.Size, fps=Optional[torch.Tensor]) -> torch.Tensor: + B, T, H, W, C = B_T_H_W_C + + if self.interpolation == "crop": + if T > 1: + return torch.cat( + [ + self.pos_embed[:, : (int(self.max_fps / curr_fps) * T) : int(self.max_fps / curr_fps), :H, :W] + for curr_fps in fps + ], + 0, + ) + else: + return self.pos_embed[:, :T, :H, :W] # image model + elif self.interpolation == "resize": + if T > 1: + return torch.cat( + [ + rearrange( + F.interpolate( + rearrange(self.pos_embed, "1 t h w c -> 1 c h w t"), + size=(H, W, T * int(curr_fps / self.min_fps)), + mode="trilinear", + align_corners=True, # important: align corner need to be true + )[:, :, :H, :W, :T], + "1 c h w t -> 1 t h w c", + ) + for curr_fps in fps + ], + 0, + ) + else: + # grab self.pos_embed at time step 0 and resize spatially + return rearrange( + F.interpolate( + rearrange(self.pos_embed[:, 0, ::], "1 h w c -> 1 c h w"), + size=(H, W), + mode="bilinear", + align_corners=True, + ), + "1 c h w -> 1 h w c", + ) + raise ValueError(f"Unknown interpolation method {self.interpolation}") + + +class SinCosPosEmb(VideoPositionEmb): + def __init__( + self, + *, # enforce keyword arguments + model_channels: int, + len_h: int, + len_w: int, + len_t: int, + is_learnable: bool = False, + interpolation: Literal["crop", "resize", "crop_resize"] = "crop", + spatial_interpolation_scale=1.0, + temporal_interpolation_scale=1.0, + init_length_for_resize: int = 16, + **kwargs, + ): + """ + Args: + interpolation (str): "crop", "resize", "crop_resize". "crop" means we crop the positional embedding to the length of the input sequence. "resize" means we resize the positional embedding to the length of the input sequence. "crop_resize" (inference only) means we first crop the positional embedding to init_length_for_resize, then resize it to the length of the input sequence. + init_length_for_resize (int): used when interpolation is "crop_resize", where we "resize" embedding during inference for model trained with "crop". We first "crop" the pos_embed to this length (used during training), then run the "resize", default 16 + """ + del kwargs # unused + super().__init__() + self.interpolation = interpolation + self.init_length_for_resize = init_length_for_resize + param = get_3d_sincos_pos_embed( + model_channels, len_h, len_w, len_t, spatial_interpolation_scale, temporal_interpolation_scale + ) + param = rearrange(param, "(h w t) c -> 1 t h w c", h=len_h, w=len_w) + if is_learnable: + self.pos_embed = nn.Parameter( + torch.from_numpy(param).float(), + ) + else: + self.register_buffer("pos_embed", torch.from_numpy(param).float(), persistent=False) + + def generate_embeddings(self, B_T_H_W_C: torch.Size, fps=Optional[torch.Tensor]) -> torch.Tensor: + B, T, H, W, C = B_T_H_W_C + if self.interpolation == "crop": + return self.pos_embed[:, :T, :H, :W] + if self.interpolation == "resize": + return rearrange( + F.interpolate( + rearrange(self.pos_embed, "1 t h w c -> 1 c h w t"), + size=(H, W, T), + mode="linear", + align_corners=False, + ), + "1 c h w t -> 1 t h w c", + ) + if self.interpolation == "crop_resize": + pos_embed_crop = self.pos_embed[:, : self.init_length_for_resize, :H, :W] # B,T,H,W,C + _, t, h, w, c = pos_embed_crop.shape + + pos_embed_crop_resize_t = rearrange( + F.interpolate( + rearrange(pos_embed_crop, "1 t h w c -> 1 (c h w) t"), + size=(T), + mode="linear", + ), + "1 (c h w) t -> 1 t h w c", + c=c, + h=h, + w=w, + ) + pos_embed_crop_resize = rearrange( + F.interpolate( + rearrange(pos_embed_crop_resize_t, "1 t h w c -> 1 (c t) h w"), + size=(H, W), + mode="bilinear", + ), + "1 (c t) h w -> 1 t h w c", + c=c, + ) + return pos_embed_crop_resize + + raise ValueError(f"Unknown interpolation method {self.interpolation}") \ No newline at end of file diff --git a/cosmos_transfer1/diffusion/networks/general_dit.py b/cosmos_transfer1/diffusion/networks/general_dit.py index 63cf9103..06c00fa1 100644 --- a/cosmos_transfer1/diffusion/networks/general_dit.py +++ b/cosmos_transfer1/diffusion/networks/general_dit.py @@ -15,12 +15,30 @@ """ A general implementation of adaln-modulated VIT-like~(DiT) transformer for video processing. +It allows us easy to switch building blocks used and their order. Its instantiation includes +* transformer on fully flattened tokens +* factored spatial and temporal attention +* factored non-overlap spatial and temporal attention +* mixing of above attention types + +Limitations: + +* In favor of simplicity and cleanness, many ops are not fused and we can do better +* such as combining mutiple adaln MLPs into one inside one transformer block. +* we use reshape heavily, which may be not efficient when its occurs unnecessary CUDA memory copy + +Purpose: +* A prototype for testing different attention types and their combinations +* Idealy, we want to know where we should allocate our resources / FLOPS / memory via extensive empirical studies """ + +from collections.abc import Container from typing import List, Optional, Tuple import torch from einops import rearrange +from megatron.core import parallel_state from torch import nn from torch.distributed import ProcessGroup, get_process_group_ranks from torchvision import transforms @@ -28,57 +46,96 @@ from cosmos_transfer1.diffusion.conditioner import DataType from cosmos_transfer1.diffusion.module.attention import get_normalization from cosmos_transfer1.diffusion.module.blocks import ( + DITBuildingBlock, FinalLayer, GeneralDITTransformerBlock, PatchEmbed, - TimestepEmbedding, - Timesteps, + SDXLTimestepEmbedding, + SDXLTimesteps, +) +from cosmos_transfer1.diffusion.module.position_embedding import ( + LearnableEmb3D, + LearnableEmb3D_FPS_Aware, + LearnablePosEmbAxis, + SinCosPosEmb, + SinCosPosEmb_FPS_Aware, + SinCosPosEmbAxis, + VideoRopePosition3DEmb, + VideoRopePositionEmb, ) -from cosmos_transfer1.diffusion.module.position_embedding import LearnablePosEmbAxis, VideoRopePosition3DEmb +from cosmos_transfer1.diffusion.training.tensor_parallel import gather_along_first_dim, scatter_along_first_dim from cosmos_transfer1.utils import log class GeneralDIT(nn.Module): """ A general implementation of adaln-modulated VIT-like~(DiT) transformer for video processing. - - Args: + Attributes: max_img_h (int): Maximum height of the input images. max_img_w (int): Maximum width of the input images. max_frames (int): Maximum number of frames in the video sequence. in_channels (int): Number of input channels (e.g., RGB channels for color images). out_channels (int): Number of output channels. - patch_spatial (tuple): Spatial resolution of patches for input processing. + patch_spatial (tuple of int): Spatial resolution of patches for input processing. patch_temporal (int): Temporal resolution of patches for input processing. concat_padding_mask (bool): If True, includes a mask channel in the input to handle padding. - block_config (str): Configuration of the transformer block. See Notes for supported block types. + block_config (str): Configuration of the transformer block, e.g., 'FA-CA-MLP', means + full attention, cross attention, and MLP in sequence in one transformer block. model_channels (int): Base number of channels used throughout the model. - num_blocks (int): Number of transformer blocks. - num_heads (int): Number of heads in the multi-head attention layers. - mlp_ratio (float): Expansion ratio for MLP blocks. - block_x_format (str): Format of input tensor for transformer blocks ('BTHWD' or 'THWBD'). - crossattn_emb_channels (int): Number of embedding channels for cross-attention. - use_cross_attn_mask (bool): Whether to use mask in cross-attention. - pos_emb_cls (str): Type of positional embeddings. - pos_emb_learnable (bool): Whether positional embeddings are learnable. - pos_emb_interpolation (str): Method for interpolating positional embeddings. - affline_emb_norm (bool): Whether to normalize affine embeddings. - use_adaln_lora (bool): Whether to use AdaLN-LoRA. - adaln_lora_dim (int): Dimension for AdaLN-LoRA. - rope_h_extrapolation_ratio (float): Height extrapolation ratio for RoPE. - rope_w_extrapolation_ratio (float): Width extrapolation ratio for RoPE. - rope_t_extrapolation_ratio (float): Temporal extrapolation ratio for RoPE. - extra_per_block_abs_pos_emb (bool): Whether to use extra per-block absolute positional embeddings. - extra_per_block_abs_pos_emb_type (str): Type of extra per-block positional embeddings. - extra_h_extrapolation_ratio (float): Height extrapolation ratio for extra embeddings. - extra_w_extrapolation_ratio (float): Width extrapolation ratio for extra embeddings. - extra_t_extrapolation_ratio (float): Temporal extrapolation ratio for extra embeddings. - - Notes: - Supported block types in block_config: - * cross_attn, ca: Cross attention - * full_attn: Full attention on all flattened tokens - * mlp, ff: Feed forward block + num_blocks (int): Number of residual blocks per resolution in the transformer. + num_heads (int): Number of heads in the multi-head self-attention layers. + spatial_attn_win_size (int): Window size for the spatial attention mechanism. + temporal_attn_win_size (int): Window size for the temporal attention mechanism. + mlp_ratio (float): Expansion ratio for the MLP (multi-layer perceptron) blocks in the transformer. + use_memory_save (bool): If True, utilizes checkpointing to reduce memory usage during training. (Deprecated) + use_checkpoint (bool): If True, utilizes checkpointing to reduce memory usage during training for all blocks. + crossattn_emb_channels (int): Number of embedding channels used in the cross-attention layers. + use_cross_attn_mask (bool): If True, applies a mask during cross-attention operations to manage sequence alignment. + pos_emb_cls (str): Type of positional embeddings used ('sincos' for sinusoidal or other types). + pos_emb_learnable (bool): Specifies if positional embeddings are learnable. + pos_emb_interpolation (str): Method used for interpolating positional embeddings, e.g., 'crop' for cropping adjustments. + block_x_format (str, optional): The format of the input tensor for the transformer block. Defaults to "BTHWD". Only support 'BTHWD' and 'THWBD'. + legacy_patch_emb (bool): If True, applies 3D convolutional layers for video inputs, otherwise, use Linear! This is for backward compatibility. + rope_h_extrapolation_ratio (float): Ratio of the height extrapolation for the rope positional embedding. + rope_w_extrapolation_ratio (float): Ratio of the width extrapolation for the rope positional embedding. + rope_t_extrapolation_ratio (float): Ratio of the temporal extrapolation for the rope positional embedding. + Note: + block_config support block type: + * spatial_sa, ssa: spatial self attention + * temporal_sa, tsa: temporal self attention + * cross_attn, ca: cross attention + * full_attn: full attention on all flatten tokens + * mlp, ff: feed forward block + * use '-' to separate different building blocks, e.g., 'FA-CA-MLP' means full attention, cross attention, and MLP in sequence in one transformer block. + + Example: + >>> # full attention, cross attention, and MLP + >>> option1_block_config = 'FA-CA-MLP' + >>> model_1 = GeneralDIT( + max_img_h=64, max_img_w=64, max_frames=32, in_channels=16, out_channels=16, + patch_spatial=2, patch_temporal=1, model_channels=768, num_blocks=10, + num_heads=16, mlp_ratio=4.0, + spatial_attn_win_size=1, temporal_attn_win_size=1, + block_config=option1_block_config + ) + >>> option2_block_config = 'SSA-CA-MLP-TSA-CA-MLP' + >>> model_2 = GeneralDIT( + max_img_h=64, max_img_w=64, max_frames=32, in_channels=16, out_channels=16, + patch_spatial=2, patch_temporal=1, model_channels=768, num_blocks=10, + num_heads=16, mlp_ratio=4.0, + spatial_attn_win_size=1, temporal_attn_win_size=1, + block_config=option2_block_config + ) + >>> # option3 model + >>> model_3 = GeneralDIT( + max_img_h=64, max_img_w=64, max_frames=32, in_channels=16, out_channels=16, + patch_spatial=2, patch_temporal=1, model_channels=768, num_blocks=10, + num_heads=16, mlp_ratio=4.0, + spatial_attn_win_size=1, temporal_attn_win_size=2, + block_config=option2_block_config + ) + >>> # Process input tensor through the model + >>> output = model(input_tensor) """ def __init__( @@ -96,7 +153,13 @@ def __init__( model_channels: int = 768, num_blocks: int = 10, num_heads: int = 16, + window_block_indexes: list = [], # index for window attention block + window_sizes: list = [], # window size for window attention block in the order of T, H, W + spatial_attn_win_size: int = 1, + temporal_attn_win_size: int = 1, mlp_ratio: float = 4.0, + use_memory_save: bool = False, + use_checkpoint: bool = False, block_x_format: str = "BTHWD", # cross attention settings crossattn_emb_channels: int = 1024, @@ -105,9 +168,14 @@ def __init__( pos_emb_cls: str = "sincos", pos_emb_learnable: bool = False, pos_emb_interpolation: str = "crop", + min_fps: int = 1, # 1 for getty video + max_fps: int = 30, # 120 for getty video but let's use 30 + additional_timestamp_channels: dict = None, # Follow SDXL, in format of {condition_name : dimension} affline_emb_norm: bool = False, # whether or not to normalize the affine embedding use_adaln_lora: bool = False, adaln_lora_dim: int = 256, + layer_mask: list = None, # whether or not a layer is used. For controlnet encoder + legacy_patch_emb: bool = True, rope_h_extrapolation_ratio: float = 1.0, rope_w_extrapolation_ratio: float = 1.0, rope_t_extrapolation_ratio: float = 1.0, @@ -116,7 +184,6 @@ def __init__( extra_h_extrapolation_ratio: float = 1.0, extra_w_extrapolation_ratio: float = 1.0, extra_t_extrapolation_ratio: float = 1.0, - layer_mask: list = None, # whether or not a layer is used. For controlnet encoder ) -> None: super().__init__() self.max_img_h = max_img_h @@ -135,7 +202,11 @@ def __init__( self.pos_emb_cls = pos_emb_cls self.pos_emb_learnable = pos_emb_learnable self.pos_emb_interpolation = pos_emb_interpolation + self.min_fps = min_fps + self.max_fps = max_fps + self.additional_timestamp_channels = additional_timestamp_channels self.affline_emb_norm = affline_emb_norm + self.legacy_patch_emb = legacy_patch_emb self.rope_h_extrapolation_ratio = rope_h_extrapolation_ratio self.rope_w_extrapolation_ratio = rope_w_extrapolation_ratio self.rope_t_extrapolation_ratio = rope_t_extrapolation_ratio @@ -148,15 +219,23 @@ def __init__( self.build_patch_embed() self.build_pos_embed() self.cp_group = None + self.sequence_parallel = getattr(parallel_state, "sequence_parallel", False) self.block_x_format = block_x_format self.use_adaln_lora = use_adaln_lora self.adaln_lora_dim = adaln_lora_dim self.t_embedder = nn.Sequential( - Timesteps(model_channels), - TimestepEmbedding(model_channels, model_channels, use_adaln_lora=use_adaln_lora), + SDXLTimesteps(model_channels), + SDXLTimestepEmbedding(model_channels, model_channels, use_adaln_lora=use_adaln_lora), ) self.blocks = nn.ModuleDict() + self.block_config = block_config + self.use_memory_save = use_memory_save + self.use_checkpoint = use_checkpoint + + assert ( + len(window_block_indexes) == 0 or block_config == "FA-CA-MLP" + ), "Block config must be FA-CA-MLP if using a combination of window attention and global attention" layer_mask = [False] * num_blocks if layer_mask is None else layer_mask assert ( @@ -170,21 +249,33 @@ def __init__( context_dim=crossattn_emb_channels, num_heads=num_heads, block_config=block_config, + window_sizes=( + window_sizes if idx in window_block_indexes else [] + ), # There will be bug if using "WA-CA-MLP" mlp_ratio=mlp_ratio, + spatial_attn_win_size=spatial_attn_win_size, + temporal_attn_win_size=temporal_attn_win_size, x_format=self.block_x_format, use_adaln_lora=use_adaln_lora, adaln_lora_dim=adaln_lora_dim, + use_checkpoint=use_checkpoint, ) self.build_decode_head() + self.build_additional_timestamp_embedder() if self.affline_emb_norm: - log.debug("Building affine embedding normalization layer") + log.critical("Building affine embedding normalization layer") self.affline_norm = get_normalization("R", model_channels) else: self.affline_norm = nn.Identity() - self.initialize_weights() + self.init_weights() + + if self.use_memory_save: + log.critical("Using checkpointing to save memory! only verified in 14B base model training!") + for block in self.blocks.values(): + block.set_memory_save() - def initialize_weights(self): + def init_weights(self): # Initialize transformer layers: def _basic_init(module): if isinstance(module, nn.Linear): @@ -209,6 +300,50 @@ def _basic_init(module): if block.adaLN_modulation[-1].bias is not None: nn.init.constant_(block.adaLN_modulation[-1].bias, 0) + # Tensor parallel + if parallel_state.is_initialized() and parallel_state.get_tensor_model_parallel_world_size() > 1: + self.initialize_tensor_parallel_weights() + + def initialize_tensor_parallel_weights(self): + """ + Initialize weights for tensor parallel layers. + + This function performs the following steps: + 1. Retrieves the tensor parallel rank. + 2. Saves the current random state. + 3. Sets a new random seed based on the tensor parallel rank. + 4. Initializes weights for attention and MLP layers in each block. + 5. Restores the original random state. + + The use of different random seeds for each rank ensures + unique initializations across parallel processes. + """ + tp_rank = parallel_state.get_tensor_model_parallel_rank() + + # Save the current random state + rng_state = torch.get_rng_state() + + # Set a new random seed based on the tensor parallel rank + torch.manual_seed(tp_rank) + + for block in self.blocks.values(): + for layer in block.blocks: + if layer.block_type in ["full_attn", "fa", "cross_attn", "ca"]: + # Initialize weights for attention layers + torch.nn.init.xavier_uniform_(layer.block.attn.to_q[0].weight) + torch.nn.init.xavier_uniform_(layer.block.attn.to_k[0].weight) + torch.nn.init.xavier_uniform_(layer.block.attn.to_v[0].weight) + torch.nn.init.xavier_uniform_(layer.block.attn.to_out[0].weight) + elif layer.block_type in ["mlp", "ff"]: + # Initialize weights for MLP layers + torch.nn.init.xavier_uniform_(layer.block.layer1.weight) + torch.nn.init.xavier_uniform_(layer.block.layer2.weight) + else: + raise ValueError(f"Unknown block type {layer.block_type}") + + # Restore the original random state + torch.set_rng_state(rng_state) + def build_decode_head(self): self.final_layer = FinalLayer( hidden_size=self.model_channels, @@ -240,20 +375,60 @@ def build_patch_embed(self): in_channels=in_channels, out_channels=model_channels, bias=False, + keep_spatio=True, + legacy_patch_emb=self.legacy_patch_emb, ) + # Initialize patch_embed like nn.Linear (instead of nn.Conv2d) + if self.legacy_patch_emb: + w = self.x_embedder.proj.weight.data + nn.init.xavier_uniform_(w.view([w.shape[0], -1])) + + def build_additional_timestamp_embedder(self): + if self.additional_timestamp_channels: + self.additional_timestamp_embedder = nn.ModuleDict() + for cond_name, cond_emb_channels in self.additional_timestamp_channels.items(): + log.critical( + f"Building additional timestamp embedder for {cond_name} with {cond_emb_channels} channels" + ) + self.additional_timestamp_embedder[cond_name] = nn.Sequential( + SDXLTimesteps(cond_emb_channels), + SDXLTimestepEmbedding(cond_emb_channels, cond_emb_channels), + ) + + def prepare_additional_timestamp_embedder(self, **kwargs): + condition_concat = [] + + for cond_name, embedder in self.additional_timestamp_embedder.items(): + condition_concat.append(embedder(kwargs[cond_name])[0]) + embedding = torch.cat(condition_concat, dim=1) + if embedding.shape[1] < self.model_channels: + embedding = nn.functional.pad(embedding, (0, self.model_channels - embedding.shape[1])) + return embedding def build_pos_embed(self): - if self.pos_emb_cls == "rope3d": + if self.pos_emb_cls == "sincos": + cls_type = SinCosPosEmb + elif self.pos_emb_cls == "learnable": + cls_type = LearnableEmb3D + elif self.pos_emb_cls == "sincos_fps_aware": + cls_type = SinCosPosEmb_FPS_Aware + elif self.pos_emb_cls == "learnable_fps_aware": + cls_type = LearnableEmb3D_FPS_Aware + elif self.pos_emb_cls == "rope": + cls_type = VideoRopePositionEmb + elif self.pos_emb_cls == "rope3d": cls_type = VideoRopePosition3DEmb else: raise ValueError(f"Unknown pos_emb_cls {self.pos_emb_cls}") - log.debug(f"Building positional embedding with {self.pos_emb_cls} class, impl {cls_type}") + log.critical(f"Building positional embedding with {self.pos_emb_cls} class, impl {cls_type}") kwargs = dict( model_channels=self.model_channels, len_h=self.max_img_h // self.patch_spatial, len_w=self.max_img_w // self.patch_spatial, len_t=self.max_frames // self.patch_temporal, + max_fps=self.max_fps, + min_fps=self.min_fps, is_learnable=self.pos_emb_learnable, interpolation=self.pos_emb_interpolation, head_dim=self.model_channels // self.num_heads, @@ -267,14 +442,20 @@ def build_pos_embed(self): if self.extra_per_block_abs_pos_emb: assert self.extra_per_block_abs_pos_emb_type in [ + "sincos", "learnable", ], f"Unknown extra_per_block_abs_pos_emb_type {self.extra_per_block_abs_pos_emb_type}" kwargs["h_extrapolation_ratio"] = self.extra_h_extrapolation_ratio kwargs["w_extrapolation_ratio"] = self.extra_w_extrapolation_ratio kwargs["t_extrapolation_ratio"] = self.extra_t_extrapolation_ratio - self.extra_pos_embedder = LearnablePosEmbAxis( - **kwargs, - ) + if self.extra_per_block_abs_pos_emb_type == "sincos": + self.extra_pos_embedder = SinCosPosEmbAxis( + **kwargs, + ) + elif self.extra_per_block_abs_pos_emb_type == "learnable": + self.extra_pos_embedder = LearnablePosEmbAxis( + **kwargs, + ) def prepare_embedded_sequence( self, @@ -304,8 +485,8 @@ def prepare_embedded_sequence( - The method of applying positional embeddings depends on the value of `self.pos_emb_cls`. - If 'rope' is in `self.pos_emb_cls` (case insensitive), the positional embeddings are generated using the `self.pos_embedder` with the shape [T, H, W]. - - If "fps_aware" is in `self.pos_emb_cls`, the positional embeddings are generated using the - `self.pos_embedder` with the fps tensor. + - If "fps_aware" is in `self.pos_emb_cls`, the positional embeddings are generated using the `self.pos_embedder` + with the fps tensor. - Otherwise, the positional embeddings are generated without considering fps. """ if self.concat_padding_mask: @@ -329,7 +510,6 @@ def prepare_embedded_sequence( x_B_T_H_W_D = x_B_T_H_W_D + self.pos_embedder(x_B_T_H_W_D, fps=fps) # [B, T, H, W, D] else: x_B_T_H_W_D = x_B_T_H_W_D + self.pos_embedder(x_B_T_H_W_D) # [B, T, H, W, D] - return x_B_T_H_W_D, None, extra_pos_emb def decoder_head( @@ -407,10 +587,30 @@ def forward_before_blocks( if scalar_feature is not None: raise NotImplementedError("Scalar feature is not implemented yet.") + timesteps_B_D = timesteps_B_D + scalar_feature.mean(dim=1) + + if self.additional_timestamp_channels: + additional_cond_B_D = self.prepare_additional_timestamp_embedder( + bs=x.shape[0], + fps=fps, + h=image_size[:, 0], + w=image_size[:, 1], + org_h=image_size[:, 2], + org_w=image_size[:, 3], + ) + + affline_emb_B_D += additional_cond_B_D + affline_scale_log_info["additional_cond_B_D"] = additional_cond_B_D.detach() affline_scale_log_info["affline_emb_B_D"] = affline_emb_B_D.detach() affline_emb_B_D = self.affline_norm(affline_emb_B_D) + # for logging purpose + self.affline_scale_log_info = affline_scale_log_info + self.affline_emb = affline_emb_B_D + self.crossattn_emb = crossattn_emb + self.crossattn_mask = crossattn_mask + if self.use_cross_attn_mask: crossattn_mask = crossattn_mask[:, None, None, :].to(dtype=torch.bool) # [B, 1, 1, length] else: @@ -427,6 +627,24 @@ def forward_before_blocks( if crossattn_mask: crossattn_mask = rearrange(crossattn_mask, "B M -> M B") + if self.sequence_parallel: + tp_group = parallel_state.get_tensor_model_parallel_group() + # Sequence parallel requires the input tensor to be scattered along the first dimension. + assert self.block_config == "FA-CA-MLP" # Only support this block config for now + T, H, W, B, D = x.shape + # variable name x_T_H_W_B_D is no longer valid. x is reshaped to THW*1*1*b*D and will be reshaped back in FinalLayer + x = x.view(T * H * W, 1, 1, B, D) + assert x.shape[0] % parallel_state.get_tensor_model_parallel_world_size() == 0 + x = scatter_along_first_dim(x, tp_group) + + if extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D is not None: + extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D = extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D.view( + T * H * W, 1, 1, B, D + ) + extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D = scatter_along_first_dim( + extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D, tp_group + ) + elif self.blocks["block0"].x_format == "BTHWD": x = x_B_T_H_W_D else: @@ -443,6 +661,199 @@ def forward_before_blocks( } return output + def forward_blocks_regular( + self, + x, + affline_emb_B_D, + crossattn_emb, + crossattn_mask, + rope_emb_L_1_1_D, + adaln_lora_B_3D, + extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D, + feature_indices, + original_shape, + x_ctrl, + return_features_early, + ): + features = [] + for name, block in self.blocks.items(): + assert ( + self.blocks["block0"].x_format == block.x_format + ), f"First block has x_format {self.blocks[0].x_format}, got {block.x_format}" + x = block( + x, + affline_emb_B_D, + crossattn_emb, + crossattn_mask, + rope_emb_L_1_1_D=rope_emb_L_1_1_D, + adaln_lora_B_3D=adaln_lora_B_3D, + extra_per_block_pos_emb=extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D, + ) + + # Extract features + block_idx = int(name.split("block")[-1]) + if block_idx in feature_indices: + B, C, T, H, W = original_shape + H = H // self.patch_spatial + W = W // self.patch_spatial + T = T // self.patch_temporal + if self.sequence_parallel: + x_feat = gather_along_first_dim(x, parallel_state.get_tensor_model_parallel_group()) + x_B_T_H_W_D = rearrange(x_feat, "(T H W) 1 1 B D -> B T H W D", T=T, H=H, W=W) + else: + x_feat = x + if self.blocks["block0"].x_format == "THWBD": + x_B_T_H_W_D = rearrange(x_feat, "T H W B D -> B T H W D", T=T, H=H, W=W) + elif self.blocks["block0"].x_format == "BTHWD": + x_B_T_H_W_D = x_feat + else: + raise ValueError(f"Unknown x_format {self.blocks[-1].x_format}") + + features.append(x_B_T_H_W_D) + + if x_ctrl is not None and name in x_ctrl: + x = x + x_ctrl[name] + # If we have all of the features, we can exit early + if return_features_early and len(features) == len(feature_indices): + return features + + if self.blocks["block0"].x_format == "THWBD": + x_B_T_H_W_D = rearrange(x, "T H W B D -> B T H W D") + elif self.blocks["block0"].x_format == "BTHWD": + x_B_T_H_W_D = x + else: + raise ValueError(f"Unknown x_format {self.blocks[-1].x_format}") + + x_B_D_T_H_W = self.decoder_head( + x_B_T_H_W_D=x_B_T_H_W_D, + emb_B_D=affline_emb_B_D, + crossattn_emb=None, + origin_shape=original_shape, + crossattn_mask=None, + adaln_lora_B_3D=adaln_lora_B_3D, + ) + + if len(feature_indices) == 0: + # no features requested, return only the model output + return x_B_D_T_H_W + else: + # score and features; score, features + return x_B_D_T_H_W, features + + def forward_blocks_memory_save( + self, + x, + affline_emb_B_D, + crossattn_emb, + crossattn_mask, + rope_emb_L_1_1_D, + adaln_lora_B_3D, + extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D, + feature_indices, + original_shape, + x_ctrl, + return_features_early, + ): + x_before_gate = 0 + x_skip = rearrange(x, "T H W B D -> (T H W) B D") + assert self.blocks["block0"].x_format == "THWBD" + if extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D is not None: + extra_per_block_pos_emb = rearrange(extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D, "T H W B D -> (T H W) B D") + else: + extra_per_block_pos_emb = None + gate_L_B_D = 1.0 + + features = [] + for name, block in self.blocks.items(): + gate_L_B_D, x_before_gate, x_skip = block( + x_before_gate, + x_skip, + gate_L_B_D, + affline_emb_B_D, + crossattn_emb, + crossattn_mask, + rope_emb_L_1_1_D=rope_emb_L_1_1_D, + adaln_lora_B_3D=adaln_lora_B_3D, + extra_per_block_pos_emb=extra_per_block_pos_emb, + ) + + # Extract features. + # Convert the block index in the memory save mode to the block index in the regular mode. + block_idx = int(name.split("block")[-1]) - 1 + if block_idx in feature_indices: + B, C, T_before_patchify, H_before_patchify, W_before_patchify = original_shape + H = H_before_patchify // self.patch_spatial + W = W_before_patchify // self.patch_spatial + T = T_before_patchify // self.patch_temporal + if self.sequence_parallel: + x_feat = gather_along_first_dim(x_skip, parallel_state.get_tensor_model_parallel_group()) + x_B_T_H_W_D = rearrange(x_feat, "(T H W) 1 1 B D -> B T H W D", T=T, H=H, W=W) + else: + x_feat = x_skip + x_B_T_H_W_D = rearrange(x_feat, "(T H W) B D -> B T H W D", T=T, H=H, W=W) + + features.append(x_B_T_H_W_D) + + new_name = f"block{block_idx}" + if x_ctrl is not None and new_name in x_ctrl: + x_ctrl_ = x_ctrl[new_name] + x_ctrl_ = rearrange(x_ctrl_, "T H W B D -> (T H W) B D") + x_skip = x_skip + x_ctrl_ + # If we have all of the features, we can exit early + if return_features_early and len(features) == len(feature_indices): + return features + + x_THW_B_D_before_gate = x_before_gate + x_THW_B_D_skip = x_skip + + B, C, T_before_patchify, H_before_patchify, W_before_patchify = original_shape + x_BT_HW_D_before_gate = rearrange( + x_THW_B_D_before_gate, + "(T H W) B D -> (B T) (H W) D", + T=T_before_patchify // self.patch_temporal, + H=H_before_patchify // self.patch_spatial, + W=W_before_patchify // self.patch_spatial, + ) + x_BT_HW_D_skip = rearrange( + x_THW_B_D_skip, + "(T H W) B D -> (B T) (H W) D", + T=T_before_patchify // self.patch_temporal, + H=H_before_patchify // self.patch_spatial, + W=W_before_patchify // self.patch_spatial, + ) + + x_BT_HW_D = self.final_layer.forward_with_memory_save( + x_BT_HW_D_before_gate=x_BT_HW_D_before_gate, + x_BT_HW_D_skip=x_BT_HW_D_skip, + gate_L_B_D=gate_L_B_D, + emb_B_D=affline_emb_B_D, + adaln_lora_B_3D=adaln_lora_B_3D, + ) + + # This is to ensure x_BT_HW_D has the correct shape because + # when we merge T, H, W into one dimension, x_BT_HW_D has shape (B * T * H * W, 1*1, D). + x_BT_HW_D = x_BT_HW_D.view( + B * T_before_patchify // self.patch_temporal, + H_before_patchify // self.patch_spatial * W_before_patchify // self.patch_spatial, + -1, + ) + x_B_D_T_H_W = rearrange( + x_BT_HW_D, + "(B T) (H W) (p1 p2 t C) -> B C (T t) (H p1) (W p2)", + p1=self.patch_spatial, + p2=self.patch_spatial, + H=H_before_patchify // self.patch_spatial, + W=W_before_patchify // self.patch_spatial, + t=self.patch_temporal, + B=B, + ) + if len(feature_indices) == 0: + # no features requested, return only the model output + return x_B_D_T_H_W + else: + # score and features; score, features + return x_B_D_T_H_W, features + def forward( self, x: torch.Tensor, @@ -450,13 +861,16 @@ def forward( crossattn_emb: torch.Tensor, crossattn_mask: Optional[torch.Tensor] = None, fps: Optional[torch.Tensor] = None, + image_size: Optional[torch.Tensor] = None, padding_mask: Optional[torch.Tensor] = None, scalar_feature: Optional[torch.Tensor] = None, data_type: Optional[DataType] = DataType.VIDEO, + x_ctrl: Optional[dict] = None, latent_condition: Optional[torch.Tensor] = None, latent_condition_sigma: Optional[torch.Tensor] = None, + feature_indices: Optional[Container[int]] = None, + return_features_early: bool = False, condition_video_augment_sigma: Optional[torch.Tensor] = None, - x_ctrl: Optional[dict] = None, **kwargs, ) -> torch.Tensor | List[torch.Tensor] | Tuple[torch.Tensor, List[torch.Tensor]]: """ @@ -465,10 +879,19 @@ def forward( timesteps: (B, ) tensor of timesteps crossattn_emb: (B, N, D) tensor of cross-attention embeddings crossattn_mask: (B, N) tensor of cross-attention masks - condition_video_augment_sigma: (B,) used in lvg(long video generation), we add noise with this sigma to - augment condition input, the lvg model will condition on the condition_video_augment_sigma value; + feature_indices: A set of feature indices (a set of integers) decides which blocks + to extract features from. If the set is non-empty, then features will be returned. + By default, feature_indices=None means extract no features. + return_features_early: If true, the forward pass returns the features once the set is complete. + This means the forward pass will not finish completely and no final output is returned. + condition_video_augment_sigma: (B,) used in lvg(long video generation), we add noise with this sigma to augment condition input, the lvg model will condition on the condition_video_augment_sigma value; we need forward_before_blocks pass to the forward_before_blocks function. """ + if feature_indices is None: + feature_indices = {} + if return_features_early and len(feature_indices) == 0: + # Exit immediately if user requested this. + return [] inputs = self.forward_before_blocks( x=x, @@ -476,6 +899,7 @@ def forward( crossattn_emb=crossattn_emb, crossattn_mask=crossattn_mask, fps=fps, + image_size=image_size, padding_mask=padding_mask, scalar_feature=scalar_feature, data_type=data_type, @@ -499,35 +923,38 @@ def forward( x.shape == extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D.shape ), f"{x.shape} != {extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D.shape} {original_shape}" - for name, block in self.blocks.items(): - assert ( - self.blocks["block0"].x_format == block.x_format - ), f"First block has x_format {self.blocks[0].x_format}, got {block.x_format}" - - x = block( + if self.use_memory_save: + return self.forward_blocks_memory_save( x, affline_emb_B_D, crossattn_emb, crossattn_mask, - rope_emb_L_1_1_D=rope_emb_L_1_1_D, - adaln_lora_B_3D=adaln_lora_B_3D, - extra_per_block_pos_emb=extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D, + rope_emb_L_1_1_D, + adaln_lora_B_3D, + extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D, + feature_indices, + original_shape, + x_ctrl, + return_features_early, ) - if x_ctrl is not None and name in x_ctrl: - x = x + x_ctrl[name] - - x_B_T_H_W_D = rearrange(x, "T H W B D -> B T H W D") - x_B_D_T_H_W = self.decoder_head( - x_B_T_H_W_D=x_B_T_H_W_D, - emb_B_D=affline_emb_B_D, - crossattn_emb=None, - origin_shape=original_shape, - crossattn_mask=None, - adaln_lora_B_3D=adaln_lora_B_3D, + return self.forward_blocks_regular( + x, + affline_emb_B_D, + crossattn_emb, + crossattn_mask, + rope_emb_L_1_1_D, + adaln_lora_B_3D, + extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D, + feature_indices, + original_shape, + x_ctrl, + return_features_early, ) - return x_B_D_T_H_W + @property + def fsdp_wrap_block_cls(self): + return DITBuildingBlock def enable_context_parallel(self, cp_group: ProcessGroup): cp_ranks = get_process_group_ranks(cp_group) @@ -574,6 +1001,29 @@ def disable_context_parallel(self): log.debug("[CP] Disable context parallelism.") + def enable_sequence_parallel(self): + self._set_sequence_parallel(True) + + def disable_sequence_parallel(self): + self._set_sequence_parallel(False) + + def _set_sequence_parallel(self, status: bool): + self.sequence_parallel = status + self.final_layer.sequence_parallel = status + for block in self.blocks.values(): + for layer in block.blocks: + if layer.block_type in ["full_attn", "fa", "cross_attn", "ca"]: + layer.block.attn.to_q[0].sequence_parallel = status + layer.block.attn.to_k[0].sequence_parallel = status + layer.block.attn.to_v[0].sequence_parallel = status + layer.block.attn.to_out[0].sequence_parallel = status + layer.block.attn.attn_op.sequence_parallel = status + elif layer.block_type in ["mlp", "ff"]: + layer.block.layer1.sequence_parallel = status + layer.block.layer2.sequence_parallel = status + else: + raise ValueError(f"Unknown block type {layer.block_type}") + @property def is_context_parallel_enabled(self): return self.cp_group is not None diff --git a/cosmos_transfer1/diffusion/networks/general_dit_ctrl_enc.py b/cosmos_transfer1/diffusion/networks/general_dit_ctrl_enc.py index 097eaf83..98557d96 100644 --- a/cosmos_transfer1/diffusion/networks/general_dit_ctrl_enc.py +++ b/cosmos_transfer1/diffusion/networks/general_dit_ctrl_enc.py @@ -19,10 +19,11 @@ from typing import List, Optional, Tuple +import numpy as np import torch from einops import rearrange -# from megatron.core import parallel_state +from megatron.core import parallel_state from torch import nn from torchvision import transforms @@ -30,6 +31,7 @@ from cosmos_transfer1.diffusion.module.blocks import PatchEmbed, zero_module from cosmos_transfer1.diffusion.module.parallel import split_inputs_cp from cosmos_transfer1.diffusion.networks.general_dit_video_conditioned import VideoExtendGeneralDIT as GeneralDIT +from cosmos_transfer1.diffusion.training.tensor_parallel import scatter_along_first_dim class GeneralDITEncoder(GeneralDIT): @@ -60,7 +62,7 @@ def __init__(self, *args, **kwargs): input_hint_block += [nn.Linear(hint_nf[i], hint_nf[i + 1]), nonlinearity] self.input_hint_block = nn.Sequential(*input_hint_block) # Initialize weights - self.initialize_weights() + self.init_weights() self.zero_blocks = nn.ModuleDict() for idx in range(num_blocks): if layer_mask[idx]: @@ -68,6 +70,11 @@ def __init__(self, *args, **kwargs): self.zero_blocks[f"block{idx}"] = zero_module(nn.Linear(model_channels, model_channels)) self.input_hint_block.append(zero_module(nn.Linear(hint_nf[-1], model_channels))) + def _set_sequence_parallel(self, status: bool): + self.zero_blocks.sequence_parallel = status + self.input_hint_block.sequence_parallel = status + super()._set_sequence_parallel(status) + def build_hint_patch_embed(self): concat_padding_mask, in_channels, patch_spatial, patch_temporal, model_channels = ( self.concat_padding_mask, @@ -83,8 +90,15 @@ def build_hint_patch_embed(self): in_channels=in_channels, out_channels=model_channels, bias=False, + keep_spatio=True, + legacy_patch_emb=self.legacy_patch_emb, ) + if self.legacy_patch_emb: + # Initialize patch_embed like nn.Linear (instead of nn.Conv2d) + w = self.x_embedder2.proj.weight.data + nn.init.xavier_uniform_(w.view([w.shape[0], -1])) + def prepare_hint_embedded_sequence( self, x_B_C_T_H_W: torch.Tensor, fps: Optional[torch.Tensor] = None, padding_mask: Optional[torch.Tensor] = None ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: @@ -125,7 +139,18 @@ def encode_hint( ), f"Expected DataType, got {type(data_type)}. We need discuss this flag later." hint_B_T_H_W_D, _ = self.prepare_hint_embedded_sequence(hint, fps=fps, padding_mask=padding_mask) - hint = rearrange(hint_B_T_H_W_D, "B T H W D -> T H W B D") + + if self.blocks["block0"].x_format == "THWBD": + hint = rearrange(hint_B_T_H_W_D, "B T H W D -> T H W B D") + if self.sequence_parallel: + tp_group = parallel_state.get_tensor_model_parallel_group() + T, H, W, B, D = hint.shape + hint = hint.view(T * H * W, 1, 1, B, -1) + hint = scatter_along_first_dim(hint, tp_group) + elif self.blocks["block0"].x_format == "BTHWD": + hint = hint_B_T_H_W_D + else: + raise ValueError(f"Unknown x_format {self.blocks[0].x_format}") guided_hint = self.input_hint_block(hint) return guided_hint @@ -137,6 +162,7 @@ def forward( crossattn_emb: torch.Tensor, crossattn_mask: Optional[torch.Tensor] = None, fps: Optional[torch.Tensor] = None, + image_size: Optional[torch.Tensor] = None, padding_mask: Optional[torch.Tensor] = None, scalar_feature: Optional[torch.Tensor] = None, data_type: Optional[DataType] = DataType.VIDEO, @@ -169,6 +195,7 @@ def forward( crossattn_emb=crossattn_emb_input, crossattn_mask=crossattn_mask_input, fps=fps, + image_size=image_size, padding_mask=padding_mask, scalar_feature=scalar_feature, data_type=data_type, @@ -201,6 +228,7 @@ def forward( ) input_list = [x, condition_video_input_mask] x = torch.cat(input_list, dim=1) + elif data_type == DataType.IMAGE: # For image, we dont have condition_video_input_mask, or condition_video_pose # We need to add the extra channel for video condition mask @@ -217,19 +245,41 @@ def forward( else: crossattn_mask = None - crossattn_emb = rearrange(crossattn_emb, "B M D -> M B D") - if crossattn_mask: - crossattn_mask = rearrange(crossattn_mask, "B M -> M B") + if self.blocks["block0"].x_format == "THWBD": + crossattn_emb = rearrange(crossattn_emb, "B M D -> M B D") + if crossattn_mask: + crossattn_mask = rearrange(crossattn_mask, "B M -> M B") outs = {} + # If also training base model, sometimes drop the controlnet branch to only train base branch. + # This is to prevent the network become dependent on controlnet branch and make control weight useless. + is_training = torch.is_grad_enabled() + is_training_base_model = any(p.requires_grad for p in base_model.parameters()) + if is_training and is_training_base_model: + coin_flip = torch.rand(B).to(x.device) > self.dropout_ctrl_branch # prob for only training base model + if self.blocks["block0"].x_format == "THWBD": + coin_flip = coin_flip[None, None, None, :, None] + elif self.blocks["block0"].x_format == "BTHWD": + coin_flip = coin_flip[:, None, None, None, None] + else: + coin_flip = 1 + num_control_blocks = self.layer_mask.index(True) - num_layers_to_use = num_control_blocks + if self.random_drop_control_blocks: + if is_training: # Use a random number of layers during training. + num_layers_to_use = np.random.randint(num_control_blocks) + 1 + elif num_layers_to_use == -1: # Evaluate using all the layers. + num_layers_to_use = num_control_blocks + else: # Use the specified number of layers during inference. + pass + else: # Use all of the layers. + num_layers_to_use = num_control_blocks control_gate_per_layer = [i < num_layers_to_use for i in range(num_control_blocks)] if isinstance(control_weight, torch.Tensor): if control_weight.ndim == 0: # Single scalar tensor - control_weight = [float(control_weight)] + control_weight = [float(control_weight)] * len(guided_hints) elif control_weight.ndim == 1: # List of scalar weights control_weight = [float(w) for w in control_weight] else: # Spatial-temporal weight maps @@ -237,6 +287,7 @@ def forward( else: control_weight = [control_weight] * len(guided_hints) + # max_norm = {} x_before_blocks = x.clone() for i, guided_hint in enumerate(guided_hints): x = x_before_blocks @@ -266,6 +317,19 @@ def forward( if scalar_feature is not None: raise NotImplementedError("Scalar feature is not implemented yet.") + if self.additional_timestamp_channels: + additional_cond_B_D = self.prepare_additional_timestamp_embedder( + bs=x.shape[0], + fps=fps, + h=image_size[:, 0], + w=image_size[:, 1], + org_h=image_size[:, 2], + org_w=image_size[:, 3], + ) + + affline_emb_B_D += additional_cond_B_D + affline_scale_log_info["additional_cond_B_D"] = additional_cond_B_D.detach() + affline_scale_log_info["affline_emb_B_D"] = affline_emb_B_D.detach() affline_emb_B_D = affline_norm(affline_emb_B_D) @@ -273,11 +337,34 @@ def forward( self.affline_scale_log_info = affline_scale_log_info self.affline_emb = affline_emb_B_D - x = rearrange(x_B_T_H_W_D, "B T H W D -> T H W B D") - if extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D is not None: - extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D = rearrange( - extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D, "B T H W D -> T H W B D" - ) + if self.blocks["block0"].x_format == "THWBD": + x = rearrange(x_B_T_H_W_D, "B T H W D -> T H W B D") + if extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D is not None: + extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D = rearrange( + extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D, "B T H W D -> T H W B D" + ) + + if self.sequence_parallel: + tp_group = parallel_state.get_tensor_model_parallel_group() + # Sequence parallel requires the input tensor to be scattered along the first dimension. + assert self.block_config == "FA-CA-MLP" # Only support this block config for now + T, H, W, B, D = x.shape + # variable name x_T_H_W_B_D is no longer valid. x is reshaped to THW*1*1*b*D and will be reshaped back in FinalLayer + x = x.view(T * H * W, 1, 1, B, D) + assert x.shape[0] % parallel_state.get_tensor_model_parallel_world_size() == 0 + x = scatter_along_first_dim(x, tp_group) + if extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D is not None: + extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D = extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D.view( + T * H * W, 1, 1, B, D + ) + extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D = scatter_along_first_dim( + extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D, tp_group + ) + + elif self.blocks["block0"].x_format == "BTHWD": + x = x_B_T_H_W_D + else: + raise ValueError(f"Unknown x_format {self.blocks[0].x_format}") for idx, (name, block) in enumerate(blocks.items()): assert ( @@ -298,36 +385,45 @@ def forward( gate = control_gate_per_layer[idx] if isinstance(control_weight[i], (float, int)) or control_weight[i].ndim < 2: - hint_val = zero_blocks[name](x) * control_weight[i] * gate + hint_val = zero_blocks[name](x) * control_weight[i] * coin_flip * gate else: # Spatial-temporal weights [num_controls, B, 1, T, H, W] control_feat = zero_blocks[name](x) - T, H, W, _, _ = control_feat.shape # Get current feature dimensions - weight_map = control_weight[i] # [B, 1, T, H, W] - if weight_map.shape[2:5] != (T, H, W): - assert weight_map.shape[2] == 8 * (T - 1) + 1 - weight_map_i = [ - torch.nn.functional.interpolate( - weight_map[:, :, :1, :, :], - size=(1, H, W), - mode="trilinear", - align_corners=False, - ) - ] - for wi in range(1, weight_map.shape[2], 8): - weight_map_i += [ + if self.blocks["block0"].x_format == "THWBD": + weight_map = control_weight[i] # [B, 1, T, H, W] + + if weight_map.shape[2:5] != (T, H, W): + assert weight_map.shape[2] == 8 * (T - 1) + 1 + weight_map_i = [ torch.nn.functional.interpolate( - weight_map[:, :, wi : wi + 8], + weight_map[:, :, :1, :, :], size=(1, H, W), mode="trilinear", align_corners=False, ) ] - weight_map = torch.cat(weight_map_i, dim=2) - # Reshape to match THWBD format - weight_map = weight_map.permute(2, 3, 4, 0, 1) # [T, H, W, B, 1] - hint_val = control_feat * weight_map * gate + for wi in range(1, weight_map.shape[2], 8): + weight_map_i += [ + torch.nn.functional.interpolate( + weight_map[:, :, wi : wi + 8], + size=(1, H, W), + mode="trilinear", + align_corners=False, + ) + ] + weight_map = torch.cat(weight_map_i, dim=2) + + # Reshape to match THWBD format + weight_map = weight_map.permute(2, 3, 4, 0, 1) # [T, H, W, B, 1] + weight_map = weight_map.view(T * H * W, 1, 1, B, 1) + if self.sequence_parallel: + weight_map = scatter_along_first_dim(weight_map, tp_group) + + else: # BTHWD format + raise NotImplementedError("BTHWD format for weight map is not implemented yet.") + hint_val = control_feat * weight_map * coin_flip * gate + if name not in outs: outs[name] = hint_val else: @@ -339,6 +435,7 @@ def forward( crossattn_emb=crossattn_emb_input, crossattn_mask=crossattn_mask_input, fps=fps, + image_size=image_size, padding_mask=padding_mask, scalar_feature=scalar_feature, data_type=data_type, diff --git a/cosmos_transfer1/diffusion/networks/general_dit_video_conditioned.py b/cosmos_transfer1/diffusion/networks/general_dit_video_conditioned.py index 266de46b..9cb2fe90 100644 --- a/cosmos_transfer1/diffusion/networks/general_dit_video_conditioned.py +++ b/cosmos_transfer1/diffusion/networks/general_dit_video_conditioned.py @@ -21,9 +21,10 @@ from torch import nn from cosmos_transfer1.diffusion.conditioner import DataType -from cosmos_transfer1.diffusion.module.blocks import TimestepEmbedding, Timesteps +from cosmos_transfer1.diffusion.module.blocks import SDXLTimesteps, SDXLTimestepEmbedding from cosmos_transfer1.diffusion.module.parallel import split_inputs_cp from cosmos_transfer1.diffusion.networks.general_dit import GeneralDIT +from cosmos_transfer1.diffusion.training.tensor_parallel import scatter_along_first_dim from cosmos_transfer1.utils import log @@ -33,18 +34,18 @@ def __init__(self, *args, in_channels=16 + 1, add_augment_sigma_embedding=False, # extra channel for video condition mask super().__init__(*args, in_channels=in_channels, **kwargs) - log.debug(f"VideoExtendGeneralDIT in_channels: {in_channels}") + log.info(f"VideoExtendGeneralDIT in_channels: {in_channels}") def build_additional_timestamp_embedder(self): super().build_additional_timestamp_embedder() if self.add_augment_sigma_embedding: log.info("Adding augment sigma embedding") self.augment_sigma_embedder = nn.Sequential( - Timesteps(self.model_channels), - TimestepEmbedding(self.model_channels, self.model_channels, use_adaln_lora=self.use_adaln_lora), + SDXLTimesteps(self.model_channels), + SDXLTimestepEmbedding(self.model_channels, self.model_channels, use_adaln_lora=self.use_adaln_lora), ) - def initialize_weights(self): + def init_weights(self): if self.add_augment_sigma_embedding: # Initialize timestep embedding for augment sigma nn.init.normal_(self.augment_sigma_embedder[1].linear_1.weight, std=0.02) @@ -54,7 +55,7 @@ def initialize_weights(self): if self.augment_sigma_embedder[1].linear_2.bias is not None: nn.init.constant_(self.augment_sigma_embedder[1].linear_2.bias, 0) - super().initialize_weights() # Call this last since it wil call TP weight init + super().init_weights() # Call this last since it wil call TP weight init def forward( self, @@ -63,6 +64,7 @@ def forward( crossattn_emb: torch.Tensor, crossattn_mask: Optional[torch.Tensor] = None, fps: Optional[torch.Tensor] = None, + image_size: Optional[torch.Tensor] = None, padding_mask: Optional[torch.Tensor] = None, scalar_feature: Optional[torch.Tensor] = None, data_type: Optional[DataType] = DataType.VIDEO, @@ -70,49 +72,56 @@ def forward( condition_video_indicator: Optional[torch.Tensor] = None, condition_video_input_mask: Optional[torch.Tensor] = None, condition_video_augment_sigma: Optional[torch.Tensor] = None, + condition_video_pose: Optional[torch.Tensor] = None, **kwargs, ) -> torch.Tensor: - """Forward pass of the video-conditioned DIT model. - - Args: - x: Input tensor of shape (B, C, T, H, W) - timesteps: Timestep tensor of shape (B,) - crossattn_emb: Cross attention embeddings of shape (B, N, D) - crossattn_mask: Optional cross attention mask of shape (B, N) - fps: Optional frames per second tensor - padding_mask: Optional padding mask tensor - scalar_feature: Optional scalar features tensor - data_type: Type of data being processed (default: DataType.VIDEO) - video_cond_bool: Optional video conditioning boolean tensor - condition_video_indicator: Optional video condition indicator tensor - condition_video_input_mask: Required mask tensor for video data type - condition_video_augment_sigma: Optional sigma values for conditional input augmentation - **kwargs: Additional keyword arguments - - Returns: - torch.Tensor: Output tensor + """Args: + condition_video_augment_sigma: (B) tensor of sigma value for the conditional input augmentation + condition_video_pose: (B, 1, T, H, W) tensor of pose condition """ B, C, T, H, W = x.shape if data_type == DataType.VIDEO: - assert condition_video_input_mask is not None, "condition_video_input_mask is required for video data type" + assert ( + condition_video_input_mask is not None + ), "condition_video_input_mask is required for video data type; check if your model_obj is extend_model.FSDPDiffusionModel or the base DiffusionModel" + if self.cp_group is not None: + condition_video_input_mask = split_inputs_cp( + condition_video_input_mask, seq_dim=2, cp_group=self.cp_group + ) + condition_video_indicator = split_inputs_cp( + condition_video_indicator, seq_dim=2, cp_group=self.cp_group + ) + if condition_video_pose is not None: + condition_video_pose = split_inputs_cp(condition_video_pose, seq_dim=2, cp_group=self.cp_group) - if parallel_state.is_initialized(): - cp_group = parallel_state.get_context_parallel_group() - condition_video_input_mask = split_inputs_cp(condition_video_input_mask, seq_dim=2, cp_group=cp_group) - condition_video_indicator = split_inputs_cp(condition_video_indicator, seq_dim=2, cp_group=cp_group) input_list = [x, condition_video_input_mask] + if condition_video_pose is not None: + if condition_video_pose.shape[2] > T: + log.warning( + f"condition_video_pose has more frames than the input video: {condition_video_pose.shape} > {x.shape}" + ) + condition_video_pose = condition_video_pose[:, :, :T, :, :].contiguous() + input_list.append(condition_video_pose) x = torch.cat( input_list, dim=1, ) + if data_type == DataType.IMAGE: + # For image, we dont have condition_video_input_mask, or condition_video_pose + # We need to add the extra channel for video condition mask + padding_channels = self.in_channels - x.shape[1] + x = torch.cat([x, torch.zeros((B, padding_channels, T, H, W), dtype=x.dtype, device=x.device)], dim=1) + else: + assert x.shape[1] == self.in_channels, f"Expected {self.in_channels} channels, got {x.shape[1]}" return super().forward( x=x, timesteps=timesteps, crossattn_emb=crossattn_emb, crossattn_mask=crossattn_mask, fps=fps, + image_size=image_size, padding_mask=padding_mask, scalar_feature=scalar_feature, data_type=data_type, @@ -127,6 +136,7 @@ def forward_before_blocks( crossattn_emb: torch.Tensor, crossattn_mask: Optional[torch.Tensor] = None, fps: Optional[torch.Tensor] = None, + image_size: Optional[torch.Tensor] = None, padding_mask: Optional[torch.Tensor] = None, scalar_feature: Optional[torch.Tensor] = None, data_type: Optional[DataType] = DataType.VIDEO, @@ -165,33 +175,76 @@ def forward_before_blocks( if scalar_feature is not None: raise NotImplementedError("Scalar feature is not implemented yet.") + timesteps_B_D = timesteps_B_D + scalar_feature.mean(dim=1) + if self.additional_timestamp_channels: + additional_cond_B_D = self.prepare_additional_timestamp_embedder( + bs=x.shape[0], + fps=fps, + h=image_size[:, 0], + w=image_size[:, 1], + org_h=image_size[:, 2], + org_w=image_size[:, 3], + ) + affline_emb_B_D += additional_cond_B_D + affline_scale_log_info["additional_cond_B_D"] = additional_cond_B_D.detach() if self.add_augment_sigma_embedding: if condition_video_augment_sigma is None: # Handling image case - # Note: for video case, when there is not condition frames, we also set it as zero, see extend_model augment_conditional_latent_frames function + # Note: for video case, when there is not condition frames, we also set it as zero, see DiffusionV2WModel augment_conditional_latent_frames function assert data_type == DataType.IMAGE, "condition_video_augment_sigma is required for video data type" condition_video_augment_sigma = torch.zeros_like(timesteps.flatten()) - affline_augment_sigma_emb_B_D, _ = self.augment_sigma_embedder(condition_video_augment_sigma.flatten()) + affline_augment_sigma_emb_B_D, adaln_lora_sigma_emb_B_3D = self.augment_sigma_embedder( + condition_video_augment_sigma.flatten() + ) affline_emb_B_D = affline_emb_B_D + affline_augment_sigma_emb_B_D affline_scale_log_info["affline_emb_B_D"] = affline_emb_B_D.detach() affline_emb_B_D = self.affline_norm(affline_emb_B_D) + # for logging purpose + self.affline_scale_log_info = affline_scale_log_info + self.affline_emb = affline_emb_B_D + self.crossattn_emb = crossattn_emb + self.crossattn_mask = crossattn_mask + if self.use_cross_attn_mask: crossattn_mask = crossattn_mask[:, None, None, :].to(dtype=torch.bool) # [B, 1, 1, length] else: crossattn_mask = None - x = rearrange(x_B_T_H_W_D, "B T H W D -> T H W B D") - if extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D is not None: - extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D = rearrange( - extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D, "B T H W D -> T H W B D" - ) - crossattn_emb = rearrange(crossattn_emb, "B M D -> M B D") - if crossattn_mask: - crossattn_mask = rearrange(crossattn_mask, "B M -> M B") + if self.blocks["block0"].x_format == "THWBD": + x = rearrange(x_B_T_H_W_D, "B T H W D -> T H W B D") + if extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D is not None: + extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D = rearrange( + extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D, "B T H W D -> T H W B D" + ) + crossattn_emb = rearrange(crossattn_emb, "B M D -> M B D") + if crossattn_mask: + crossattn_mask = rearrange(crossattn_mask, "B M -> M B") + + if self.sequence_parallel: + tp_group = parallel_state.get_tensor_model_parallel_group() + # Sequence parallel requires the input tensor to be scattered along the first dimension. + assert self.block_config == "FA-CA-MLP" # Only support this block config for now + T, H, W, B, D = x.shape + # variable name x_T_H_W_B_D is no longer valid. x is reshaped to THW*1*1*b*D and will be reshaped back in FinalLayer + x = x.view(T * H * W, 1, 1, B, D) + assert x.shape[0] % parallel_state.get_tensor_model_parallel_world_size() == 0 + x = scatter_along_first_dim(x, tp_group) + if extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D is not None: + extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D = extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D.view( + T * H * W, 1, 1, B, D + ) + extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D = scatter_along_first_dim( + extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D, tp_group + ) + + elif self.blocks["block0"].x_format == "BTHWD": + x = x_B_T_H_W_D + else: + raise ValueError(f"Unknown x_format {self.blocks[0].x_format}") output = { "x": x, "affline_emb_B_D": affline_emb_B_D, diff --git a/cosmos_transfer1/diffusion/training/callbacks/every_n.py b/cosmos_transfer1/diffusion/training/callbacks/every_n.py new file mode 100644 index 00000000..bf6d5e6e --- /dev/null +++ b/cosmos_transfer1/diffusion/training/callbacks/every_n.py @@ -0,0 +1,86 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from abc import abstractmethod +from typing import Optional + +import torch + +from cosmos_transfer1.utils import distributed, log +from cosmos_transfer1.utils.callback import Callback +from cosmos_transfer1.utils.model import Model +from cosmos_transfer1.utils.trainer import Trainer + + +class EveryN(Callback): + def __init__( + self, + every_n: Optional[int] = None, + step_size: int = 1, + barrier_after_run: bool = True, + run_at_start: bool = False, + ) -> None: + """Constructor for `EveryN`. + + Args: + every_n (int): Frequency with which callback is run during training. + step_size (int): Size of iteration step count. Default 1. + barrier_after_run (bool): Whether to have a distributed barrier after each execution. Default True, to avoid timeouts. + run_at_start (bool): Whether to run at the beginning of training. Default False. + """ + self.every_n = every_n + if self.every_n == 0: + log.warning( + f"every_n is set to 0. Callback {self.__class__.__name__} will be invoked only once in the beginning of the training. Calls happens on_training_step_end will be skipped." + ) + + self.step_size = step_size + self.barrier_after_run = barrier_after_run + self.run_at_start = run_at_start + + def on_training_step_end( + self, + model: Model, + data_batch: dict[str, torch.Tensor], + output_batch: dict[str, torch.Tensor], + loss: torch.Tensor, + iteration: int = 0, + ) -> None: + # every_n = 0 is a special case which means every_n_impl will be called only once in the beginning of the training + if self.every_n != 0: + trainer = self.trainer + global_step = iteration // self.step_size + should_run = (iteration == 1 and self.run_at_start) or ( + global_step % self.every_n == 0 + ) # (self.every_n - 1) + if should_run: + log.debug(f"Callback {self.__class__.__name__} fired on train_batch_end step {global_step}") + self.every_n_impl(trainer, model, data_batch, output_batch, loss, iteration) + log.debug(f"Callback {self.__class__.__name__} finished on train_batch_end step {global_step}") + # add necessary barrier to avoid timeout + if self.barrier_after_run: + distributed.barrier() + + @abstractmethod + def every_n_impl( + self, + trainer: Trainer, + model: Model, + data_batch: dict[str, torch.Tensor], + output_batch: dict[str, torch.Tensor], + loss: torch.Tensor, + iteration: int, + ) -> None: + ... diff --git a/cosmos_transfer1/diffusion/training/callbacks/grad_clip.py b/cosmos_transfer1/diffusion/training/callbacks/grad_clip.py new file mode 100644 index 00000000..a3ccd88d --- /dev/null +++ b/cosmos_transfer1/diffusion/training/callbacks/grad_clip.py @@ -0,0 +1,101 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass +from typing import Tuple + +import torch +from megatron.core import parallel_state +from torch.distributed.fsdp import FullyShardedDataParallel as FSDP + +from cosmos_transfer1.utils import distributed +from cosmos_transfer1.utils.callback import GradClip as GradClipImage +from cosmos_transfer1.utils.callback import _fused_nan_to_num +from cosmos_transfer1.utils.model import Model + + +@dataclass +class _MagnitudeRecord: + state: float = 0 + iter_count: int = 0 + + def reset(self) -> None: + self.state = 0 + self.iter_count = 0 + + def update(self, cur_state: torch.Tensor) -> None: + self.state += cur_state + self.iter_count += 1 + + def get_stat(self) -> Tuple[float, float]: + if self.iter_count > 0: + avg_state = self.state / self.iter_count + avg_state = avg_state.item() + else: + avg_state = 0 + self.reset() + return avg_state + + +class GradClip(GradClipImage): + ''' + adds support for TP + ''' + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + self.img_mag_log = _MagnitudeRecord() + self.video_mag_log = _MagnitudeRecord() + self._cur_state = None + + def on_training_step_start(self, model: Model, data_batch: dict[str, torch.Tensor], iteration: int = 0) -> None: + if model.is_image_batch(data_batch): + self._cur_state = self.img_mag_log + else: + self._cur_state = self.video_mag_log + + def on_before_optimizer_step( + self, + model_ddp: distributed.DistributedDataParallel, + optimizer: torch.optim.Optimizer, + scheduler: torch.optim.lr_scheduler.LRScheduler, + grad_scaler: torch.amp.GradScaler, + iteration: int = 0, + ) -> None: + del optimizer, scheduler + if isinstance(model_ddp, distributed.DistributedDataParallel): + model = model_ddp.module + else: + model = model_ddp + params = [] + if self.model_key is not None: + items = self.model_key.split(".") + for item in items: + model = getattr(model, item) + if self.force_finite: + for param in model.parameters(): + if param.grad is not None: + params.append(param.grad) + # torch.nan_to_num(param.grad, nan=0, posinf=0, neginf=0, out=param.grad) + _fused_nan_to_num(params) + + if isinstance(model, FSDP) and self.fsdp_enabled: + total_norm = model.clip_grad_norm_(self.clip_norm) + else: + if parallel_state.is_initialized() and parallel_state.get_tensor_model_parallel_world_size() > 1: + total_norm = model_ddp.module.clip_grad_norm_(self.clip_norm) + else: + total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), self.clip_norm, foreach=True) + + self._cur_state.update(total_norm) diff --git a/cosmos_transfer1/diffusion/training/callbacks/iter_speed.py b/cosmos_transfer1/diffusion/training/callbacks/iter_speed.py new file mode 100644 index 00000000..14ce02dc --- /dev/null +++ b/cosmos_transfer1/diffusion/training/callbacks/iter_speed.py @@ -0,0 +1,82 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import time + +import torch +from torch import Tensor + +from cosmos_transfer1.diffusion.training.callbacks.every_n import EveryN +from cosmos_transfer1.utils import log +from cosmos_transfer1.utils.distributed import rank0_only +from cosmos_transfer1.utils.model import Model +from cosmos_transfer1.utils.trainer import Trainer + + +class IterSpeed(EveryN): + """ + Args: + hit_thres (int): Number of iterations to wait before logging. + """ + + def __init__(self, *args, hit_thres: int = 5, **kwargs): + super().__init__(*args, **kwargs) + self.time = None + self.hit_counter = 0 + self.hit_thres = hit_thres + self.name = self.__class__.__name__ + self.last_hit_time = time.time() + + def on_training_step_end( + self, + model: Model, + data_batch: dict[str, torch.Tensor], + output_batch: dict[str, torch.Tensor], + loss: torch.Tensor, + iteration: int = 0, + ) -> None: + if self.hit_counter < self.hit_thres: + log.info( + f"Iteration {iteration}: " + f"Hit counter: {self.hit_counter + 1}/{self.hit_thres} | " + f"Loss: {loss.item():.4f} | " + f"Time: {time.time() - self.last_hit_time:.2f}s" + ) + self.hit_counter += 1 + self.last_hit_time = time.time() + #! useful for large scale training and avoid oom crash in the first two iterations!!! + torch.cuda.synchronize() + return + super().on_training_step_end(model, data_batch, output_batch, loss, iteration) + + @rank0_only + def every_n_impl( + self, + trainer: Trainer, + model: Model, + data_batch: dict[str, Tensor], + output_batch: dict[str, Tensor], + loss: Tensor, + iteration: int, + ) -> None: + if self.time is None: + self.time = time.time() + return + cur_time = time.time() + iter_speed = (cur_time - self.time) / self.every_n / self.step_size + + log.info(f"{iteration} : iter_speed {iter_speed:.2f} seconds per iteration | Loss: {loss.item():.4f}") + + self.time = cur_time diff --git a/cosmos_transfer1/diffusion/training/callbacks/low_precision.py b/cosmos_transfer1/diffusion/training/callbacks/low_precision.py new file mode 100644 index 00000000..afd7b62b --- /dev/null +++ b/cosmos_transfer1/diffusion/training/callbacks/low_precision.py @@ -0,0 +1,41 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch + +from cosmos_transfer1.utils.trainer import Trainer +from cosmos_transfer1.utils.callback import LowPrecisionCallback as BaseCallback +from cosmos_transfer1.utils.config import Config +from cosmos_transfer1.utils.model import Model + + +class LowPrecisionCallback(BaseCallback): + """ + Config with non-primitive type makes it difficult to override the option. + The callback gets precision from model.precision instead. + """ + + def __init__(self, config: Config, trainer: Trainer, update_iter: int): + self.config = config + self.trainer = trainer + self.update_iter = update_iter + + def on_train_start(self, model: Model, iteration: int = 0) -> None: + assert model.precision in [ + torch.bfloat16, + torch.float16, + torch.half, + ], "LowPrecisionCallback must use a low precision dtype." + self.precision_type = model.precision diff --git a/cosmos_transfer1/diffusion/training/datasets/data_sources/item_dataset.py b/cosmos_transfer1/diffusion/training/datasets/data_sources/item_dataset.py new file mode 100644 index 00000000..0a4dcd91 --- /dev/null +++ b/cosmos_transfer1/diffusion/training/datasets/data_sources/item_dataset.py @@ -0,0 +1,22 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import dataclasses + + +@dataclasses.dataclass +class ItemDatasetConfig: + path: str + length: int diff --git a/cosmos_transfer1/diffusion/training/datasets/dataset_utils.py b/cosmos_transfer1/diffusion/training/datasets/dataset_utils.py new file mode 100644 index 00000000..963e4c9d --- /dev/null +++ b/cosmos_transfer1/diffusion/training/datasets/dataset_utils.py @@ -0,0 +1,311 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Adapted from: +https://github.com/bytedance/IRASim/blob/main/dataset/dataset_util.py +""" + +import base64 +import math +import os +from io import BytesIO + +import numpy as np +import torch +import torch.distributed as dist +import torchvision.transforms.functional as F +from PIL import Image + + +def is_dist_avail_and_initialized(): + if not dist.is_available(): + return False + if not dist.is_initialized(): + return False + return True + + +def get_rank(): + if not is_dist_avail_and_initialized(): + return 0 + return dist.get_rank() + + +def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): + """ + embed_dim: output dimension for each position + pos: a list of positions to be encoded: size (M,) + out: (M, D) + """ + assert embed_dim % 2 == 0 + omega = np.arange(embed_dim // 2, dtype=np.float32) + omega /= embed_dim / 2.0 + omega = 1.0 / 10000**omega # (D/2,) + + pos = pos.reshape(-1) # (M,) + out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product + + emb_sin = np.sin(out) # (M, D/2) + emb_cos = np.cos(out) # (M, D/2) + + emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D) + return emb + + +def get_2d_sincos_pos_embed_from_grid(embed_dim, grid): + assert embed_dim % 2 == 0 + + # use half of dimensions to encode grid_h + emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2) + emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2) + + emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D) + return emb + + +def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False): + """ + grid_size: int of the grid height and width + return: + pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token) + """ + grid_h = np.arange(grid_size, dtype=np.float32) + grid_w = np.arange(grid_size, dtype=np.float32) + grid = np.meshgrid(grid_w, grid_h) # here w goes first + grid = np.stack(grid, axis=0) + + grid = grid.reshape([2, 1, grid_size, grid_size]) + pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid) + if cls_token: + pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0) + return pos_embed + + +def b64_2_img(data: str): + image_b64 = base64.b64decode(data) + img = Image.open(BytesIO(image_b64)).convert("RGB") + return img + + +def get_continuous_action(d_acts, c_act_max, c_act_min, n_bins): + c_act_max = c_act_max.to(d_acts.device) + c_act_min = c_act_min.to(d_acts.device) + c_acts = d_acts / (n_bins - 1) * (c_act_max - c_act_min) + c_act_min + return c_acts + + +def alpha2rotm(a): + """Alpha euler angle to rotation matrix.""" + rotm = np.array([[1, 0, 0], [0, np.cos(a), -np.sin(a)], [0, np.sin(a), np.cos(a)]]) + return rotm + + +def beta2rotm(b): + """Beta euler angle to rotation matrix.""" + rotm = np.array([[np.cos(b), 0, np.sin(b)], [0, 1, 0], [-np.sin(b), 0, np.cos(b)]]) + return rotm + + +def gamma2rotm(c): + """Gamma euler angle to rotation matrix.""" + rotm = np.array([[np.cos(c), -np.sin(c), 0], [np.sin(c), np.cos(c), 0], [0, 0, 1]]) + return rotm + + +def euler2rotm(euler_angles): + """Euler angle (ZYX) to rotation matrix.""" + alpha = euler_angles[0] + beta = euler_angles[1] + gamma = euler_angles[2] + + rotm_a = alpha2rotm(alpha) + rotm_b = beta2rotm(beta) + rotm_c = gamma2rotm(gamma) + + rotm = rotm_c @ rotm_b @ rotm_a + + return rotm + + +def isRotm(R): + # Checks if a matrix is a valid rotation matrix. + # Forked from Andy Zeng + Rt = np.transpose(R) + shouldBeIdentity = np.dot(Rt, R) + I = np.identity(3, dtype=R.dtype) + n = np.linalg.norm(I - shouldBeIdentity) + return n < 1e-6 + + +def rotm2euler(R): + # Forked from: https://learnopencv.com/rotation-matrix-to-euler-angles/ + # R = Rz * Ry * Rx + assert isRotm(R) + sy = math.sqrt(R[0, 0] * R[0, 0] + R[1, 0] * R[1, 0]) + singular = sy < 1e-6 + + if not singular: + x = math.atan2(R[2, 1], R[2, 2]) + y = math.atan2(-R[2, 0], sy) + z = math.atan2(R[1, 0], R[0, 0]) + else: + x = math.atan2(-R[1, 2], R[1, 1]) + y = math.atan2(-R[2, 0], sy) + z = 0 + + # (-pi , pi] + while x > np.pi: + x -= 2 * np.pi + while x <= -np.pi: + x += 2 * np.pi + while y > np.pi: + y -= 2 * np.pi + while y <= -np.pi: + y += 2 * np.pi + while z > np.pi: + z -= 2 * np.pi + while z <= -np.pi: + z += 2 * np.pi + return np.array([x, y, z]) + + +def get_converted_fp32_paths(deepspeed_ckpt_path): + deepspeed_ckpt_path = deepspeed_ckpt_path.rstrip("/") + ckpt_dir = os.path.dirname(deepspeed_ckpt_path) + ckpt_name = os.path.basename(deepspeed_ckpt_path) + fp32_ckpt_name = f"{ckpt_name}.fp32.pt" + converted_path = os.path.join(ckpt_dir, fp32_ckpt_name) + return converted_path + + +def quat2rotm(quat): + """Quaternion to rotation matrix. + + Args: + quat (4, numpy array): quaternion x, y, z, w + Returns: + rotm (3x3 numpy array): rotation matrix + """ + w = quat[3] + x = quat[0] + y = quat[1] + z = quat[2] + + s = w * w + x * x + y * y + z * z + + rotm = np.array( + [ + [1 - 2 * (y * y + z * z) / s, 2 * (x * y - z * w) / s, 2 * (x * z + y * w) / s], + [2 * (x * y + z * w) / s, 1 - 2 * (x * x + z * z) / s, 2 * (y * z - x * w) / s], + [2 * (x * z - y * w) / s, 2 * (y * z + x * w) / s, 1 - 2 * (x * x + y * y) / s], + ] + ) + + return rotm + + +class Resize_Preprocess: + def __init__(self, size): + """ + Initialize the preprocessing class with the target size. + Args: + size (tuple): The target height and width as a tuple (height, width). + """ + self.size = size + + def __call__(self, video_frames): + """ + Apply the transformation to each frame in the video. + Args: + video_frames (torch.Tensor): A tensor representing a batch of video frames. + Returns: + torch.Tensor: The transformed video frames. + """ + # Resize each frame in the video + resized_frames = torch.stack([F.resize(frame, self.size, antialias=True) for frame in video_frames]) + return resized_frames + + +class Preprocess: + def __init__(self, size): + self.size = size + + def __call__(self, clip): + clip = Preprocess.resize_scale(clip, self.size[0], self.size[1], interpolation_mode="bilinear") + return clip + + def __repr__(self) -> str: + return f"{self.__class__.__name__}(size={self.size})" + + @staticmethod + def resize_scale(clip, target_height, target_width, interpolation_mode): + target_ratio = target_height / target_width + H = clip.size(-2) + W = clip.size(-1) + clip_ratio = H / W + if clip_ratio > target_ratio: + scale_ = target_width / W + else: + scale_ = target_height / H + return torch.nn.functional.interpolate(clip, scale_factor=scale_, mode=interpolation_mode, align_corners=False) + + +class ToTensorVideo: + """ + Convert tensor data type from uint8 to float, divide value by 255.0 and + permute the dimensions of clip tensor + """ + + def __init__(self): + pass + + def __call__(self, clip): + """ + Args: + clip (torch.tensor, dtype=torch.uint8): Size is (T, C, H, W) + Return: + clip (torch.tensor, dtype=torch.float): Size is (T, C, H, W) + """ + return to_tensor(clip) + + def __repr__(self) -> str: + return self.__class__.__name__ + + +def to_tensor(clip): + """ + Convert tensor data type from uint8 to float, divide value by 255.0 and + permute the dimensions of clip tensor + Args: + clip (torch.tensor, dtype=torch.uint8): Size is (T, C, H, W) + Return: + clip (torch.tensor, dtype=torch.float): Size is (T, C, H, W) + """ + _is_tensor_video_clip(clip) + if not clip.dtype == torch.uint8: + raise TypeError("clip tensor should have data type uint8. Got %s" % str(clip.dtype)) + # return clip.float().permute(3, 0, 1, 2) / 255.0 + return clip.float() / 255.0 + + +def _is_tensor_video_clip(clip): + if not torch.is_tensor(clip): + raise TypeError("clip should be Tensor. Got %s" % type(clip)) + + if not clip.ndimension() == 4: + raise ValueError("clip should be 4D. Got %dD" % clip.dim()) + + return True diff --git a/cosmos_transfer1/diffusion/training/datasets/dataset_video.py b/cosmos_transfer1/diffusion/training/datasets/dataset_video.py new file mode 100644 index 00000000..6635ace6 --- /dev/null +++ b/cosmos_transfer1/diffusion/training/datasets/dataset_video.py @@ -0,0 +1,206 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Run this command to interactively debug: +PYTHONPATH=. python cosmos_transfer1/diffusion/training/datasets/dataset_gear.py + +Adapted from: +https://github.com/bytedance/IRASim/blob/main/dataset/dataset_3D.py +""" + +import os +import pickle +import traceback +import warnings +from concurrent.futures import ThreadPoolExecutor, as_completed + +import numpy as np +import torch +from decord import VideoReader, cpu +from torch.utils.data import Dataset +from torchvision import transforms as T +from tqdm import tqdm + +from cosmos_transfer1.diffusion.training.datasets.dataset_utils import Resize_Preprocess, ToTensorVideo + + +class Dataset(Dataset): + def __init__( + self, + dataset_dir, + sequence_interval, + num_frames, + video_size, + start_frame_interval=1, + ): + """Dataset class for loading image-text-to-video generation data. + + Args: + dataset_dir (str): Base path to the dataset directory + sequence_interval (int): Interval between sampled frames in a sequence + num_frames (int): Number of frames to load per sequence + video_size (list): Target size [H,W] for video frames + + Returns dict with: + - video: RGB frames tensor [T,C,H,W] + - video_name: Dict with episode/frame metadata + """ + + super().__init__() + self.dataset_dir = dataset_dir + self.start_frame_interval = start_frame_interval + self.sequence_interval = sequence_interval + self.sequence_length = num_frames + + video_dir = os.path.join(self.dataset_dir, "videos") + self.video_paths = [os.path.join(video_dir, f) for f in os.listdir(video_dir) if f.endswith(".mp4")] + # print(f"{len(self.video_paths)} trajectories in total") + print(f"{len(self.video_paths)} videos in total") + + # self.t5_dir = os.path.join(self.dataset_dir, "labels") + self.t5_dir = os.path.join(self.dataset_dir, "t5_xxl") + self.samples = self._init_samples(self.video_paths) + self.samples = sorted(self.samples, key=lambda x: (x["video_path"], x["frame_ids"][0])) + print(f"{len(self.samples)} samples in total") + self.wrong_number = 0 + self.preprocess = T.Compose([ToTensorVideo(), Resize_Preprocess(tuple(video_size))]) + + def __str__(self): + return f"{len(self.video_paths)} samples from {self.dataset_dir}" + + def _init_samples(self, video_paths): + samples = [] + with ThreadPoolExecutor(32) as executor: + future_to_video_path = { + executor.submit(self._load_and_process_video_path, video_path): video_path for video_path in video_paths + } + for future in tqdm(as_completed(future_to_video_path), total=len(video_paths)): + samples.extend(future.result()) + return samples + + def _load_and_process_video_path(self, video_path): + vr = VideoReader(video_path, ctx=cpu(0), num_threads=2) + n_frames = len(vr) + + samples = [] + for frame_i in range(0, n_frames, self.start_frame_interval): + sample = dict() + sample["video_path"] = video_path + sample["t5_embedding_path"] = os.path.join( + # self.t5_dir, os.path.basename(video_path).replace(".mp4", ".npy") + self.t5_dir, + os.path.basename(video_path).replace(".mp4", ".pickle"), + ) + sample["frame_ids"] = [] + curr_frame_i = frame_i + while True: + if curr_frame_i > (n_frames - 1): + break + sample["frame_ids"].append(curr_frame_i) + if len(sample["frame_ids"]) == self.sequence_length: + break + curr_frame_i += self.sequence_interval + # make sure there are sequence_length number of frames + if len(sample["frame_ids"]) == self.sequence_length: + samples.append(sample) + return samples + + def __len__(self): + return len(self.samples) + + def _load_video(self, video_path, frame_ids): + vr = VideoReader(video_path, ctx=cpu(0), num_threads=2) + assert (np.array(frame_ids) < len(vr)).all() + assert (np.array(frame_ids) >= 0).all() + vr.seek(0) + frame_data = vr.get_batch(frame_ids).asnumpy() + try: + fps = vr.get_avg_fps() + except Exception: # failed to read FPS + fps = 24 + return frame_data, fps + + def _get_frames(self, video_path, frame_ids): + frames, fps = self._load_video(video_path, frame_ids) + frames = frames.astype(np.uint8) + frames = torch.from_numpy(frames).permute(0, 3, 1, 2) # (l, c, h, w) + frames = self.preprocess(frames) + frames = torch.clamp(frames * 255.0, 0, 255).to(torch.uint8) + return frames, fps + + def __getitem__(self, index): + try: + sample = self.samples[index] + video_path = sample["video_path"] + frame_ids = sample["frame_ids"] + + data = dict() + + video, fps = self._get_frames(video_path, frame_ids) + video = video.permute(1, 0, 2, 3) # Rearrange from [T, C, H, W] to [C, T, H, W] + data["video"] = video + data["video_name"] = { + "video_path": video_path, + "t5_embedding_path": sample["t5_embedding_path"], + "start_frame_id": str(frame_ids[0]), + } + + # Just add these to fit the interface + # t5_embedding = np.load(sample["t5_embedding_path"])[0] + with open(sample["t5_embedding_path"], "rb") as f: + t5_embedding = pickle.load(f)[0] + + data["t5_text_embeddings"] = torch.from_numpy(t5_embedding).cuda() + data["t5_text_mask"] = torch.ones(512, dtype=torch.int64).cuda() + data["fps"] = fps + data["image_size"] = torch.tensor([704, 1280, 704, 1280]).cuda() + data["num_frames"] = self.sequence_length + data["padding_mask"] = torch.zeros(1, 704, 1280).cuda() + + return data + except Exception: + warnings.warn( + f"Invalid data encountered: {self.samples[index]['video_path']}. Skipped " + f"(by randomly sampling another sample in the same dataset)." + ) + warnings.warn("FULL TRACEBACK:") + warnings.warn(traceback.format_exc()) + self.wrong_number += 1 + print(self.wrong_number) + return self[np.random.randint(len(self.samples))] + + +if __name__ == "__main__": + dataset = Dataset( + dataset_dir="assets/example_training_data/", + sequence_interval=1, + num_frames=57, + video_size=[240, 360], + ) + + indices = [0, 13, 200, -1] + for idx in indices: + data = dataset[idx] + print( + ( + f"{idx=} " + f"{data['video'].sum()=}\n" + f"{data['video'].shape=}\n" + f"{data['video_name']=}\n" + f"{data['t5_text_embeddings'].shape=}\n" + "---" + ) + ) diff --git a/cosmos_transfer1/diffusion/training/functional/loss.py b/cosmos_transfer1/diffusion/training/functional/loss.py new file mode 100644 index 00000000..22d11006 --- /dev/null +++ b/cosmos_transfer1/diffusion/training/functional/loss.py @@ -0,0 +1,135 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Tuple, Union + +import torch + +from cosmos_transfer1.diffusion.functional.batch_ops import batch_mul + + +def create_per_sample_loss_mask( + loss_masking_cfg: dict, + data_batch: dict, + x_shape: Tuple[int], + dtype: torch.dtype, + device: Union[str, torch.device] = "cuda", +): + """ + Creates a per-sample loss mask based on the given configuration and input data batch. + + This function generates a dictionary of loss masks for each specified key in the loss masking configuration. + For keys present in both the configuration and the data batch, the corresponding data batch value is used. + For keys present only in the configuration, a tensor of zeros with the specified shape is created. + Additionally, it computes loss mask weights for each key based on the configuration values and adjusts them + based on the presence of certain keys in the data batch, such as "skip_face" and "object_loss_map". + + Note: + - The original `loss_masking_cfg` and `data_batch` are not modified by this function. + - For image data, it is assumed that the channel is always the first dimension. + - `skip_face` is for face regions that should be skipped during training, the key is provided so that we can generate + diverse human and avoid collapse to a single face given certain prompts. The issue happens for getty projects, + where face distribution in the dataset is high unbalanced that single man face can be shown in more than 100+ images. + + Parameters: + loss_masking_cfg (dict): Configuration for loss masking, specifying which keys to include and their weights. + data_batch (dict): The batch of data containing actual data points and potential mask indicators like "skip_face". + x_shape (tuple): The shape of the input data, used to initialize zero masks for keys not in the data batch. + dtype (torch.dtype): The data type for the tensors in the loss masks. + device (str, optional): The device on which to create the tensors. Defaults to 'cuda'. + + Returns: + dict: A dictionary containing combined loss masks adjusted according to the `loss_masking_cfg` and `data_batch`. + + Raises: + AssertionError: If "skip_face" is not present in `data_batch`. + + Note: `create_combined_loss_mask` is assumed to be a separate function that combines individual loss masks into a + single mask or set of masks based on the given parameters. Its behavior should be documented separately. + """ + loss_mask_data: dict = {} + for key in loss_masking_cfg: + if key not in data_batch: + loss_mask_data[key] = torch.zeros((x_shape[0], 1, x_shape[2], x_shape[3]), device=device) + else: + loss_mask_data[key] = data_batch[key] + + if "skip_face" not in data_batch: + # When skip_face is not there in data_dict, use 0 as default. This will not skip any sample. + data_batch["skip_face"] = torch.zeros((x_shape[0],), dtype=dtype, device=device) + + loss_mask_weight: dict = {} + for k, v in loss_masking_cfg.items(): + loss_mask_weight[k] = torch.tensor(v, device=device).expand(data_batch["skip_face"].size()) + + if "human_face_mask" in loss_mask_weight: + loss_mask_weight["human_face_mask"] = (1 - data_batch["skip_face"]) * loss_mask_weight["human_face_mask"] + + if "object_loss_map" in data_batch: + loss_mask_weight["object_loss_map"] = torch.ones(data_batch["object_loss_map"].shape[0], device=device) + + return create_combined_loss_mask(loss_mask_data, x_shape, dtype, device, loss_mask_weight) + + +def create_combined_loss_mask(data, x_shape, dtype, device="cuda", loss_masking=None): + """ + Creates a combined loss mask from multiple input masks. + + This function combines several loss masks into a single mask. In regions where masks overlap, + the highest value is assigned. Non-overlapping regions are assigned a default value of 1. + Regions with a mask value of zero are explicitly zeroed out, which is essential for padded loss calculations. + + Example: + Given the following masks and weights: + mask1: [0, 1, 1, 1, 0, 0], weight: 2 + mask2: [1, 0, 1, 0, 0, 0], weight: 4 + mask3: [0, 1, 0, 0, 0, 0], weight: 0 + The resulting combined loss mask would be: + [4, 0, 4, 2, 1, 1] + + Parameters: + data (dict): Contains the loss masks and their weights. + x_shape (tuple): The shape of the output mask. + dtype: The data type for the output mask. + device: The device on which the output mask will be allocated. + loss_masking: The loss masking weight configuration. + + Returns: + torch.Tensor: The combined loss mask. + """ + + loss_mask = torch.ones(x_shape, dtype=dtype, device=device) + zero_mask = torch.ones(x_shape, dtype=dtype, device=device) + + if loss_masking: + for key in loss_masking: + # Repeat mask along channel's dimension. ndim=4 for images. + repeat_dims = (1, x_shape[1]) + tuple([1] * (data[key].ndim - 2)) + mask_key = torch.tile(data[key], dims=repeat_dims) + weight_key = loss_masking[key] + + # handle zero weight case + is_zero_weight = (weight_key == 0).float()[:, None, None, None] + zero_mask = zero_mask * ( + (1 - is_zero_weight) * torch.ones(x_shape, dtype=dtype, device=device) + + is_zero_weight * (1 - mask_key.bool().float()) + ) + + # calculate weights + no_mask_region = (mask_key.bool() == 0).float() + loss_mask = batch_mul(mask_key, weight_key) + batch_mul(no_mask_region, loss_mask) + + loss_mask_final = loss_mask * zero_mask + return loss_mask_final diff --git a/cosmos_transfer1/diffusion/training/functional/lr_scheduler.py b/cosmos_transfer1/diffusion/training/functional/lr_scheduler.py new file mode 100644 index 00000000..007fe1d0 --- /dev/null +++ b/cosmos_transfer1/diffusion/training/functional/lr_scheduler.py @@ -0,0 +1,178 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Optional + +import numpy as np + +from cosmos_transfer1.utils import distributed, log + + +class TeroPolyScheduler: + def __init__( + self, + total_Mimg: int, + batch_size: int, + ref_Mimg: Optional[int] = None, + ref_batches: float = 70e3 / 1024, + max_lr_ratio: Optional[float] = 1.0, + min_lr_ratio: Optional[float] = None, + rampup_Mimg: float = 0, + rampdown_Mimg: int = 0, + verbosity_interval: int = 0, + formula: str = "poly", + poly_exp: float = 0.5, + ): + self.total_Mimg = total_Mimg + self.batch_size = batch_size * distributed.get_world_size() + self.ref_Mimg = ref_Mimg or ref_batches * batch_size / 1e6 + self.ref_batches = ref_batches + self.max_lr_ratio = max_lr_ratio + self.min_lr_ratio = min_lr_ratio + self.rampup_Mimg = rampup_Mimg + self.rampdown_Mimg = rampdown_Mimg + self.verbosity_interval = verbosity_interval + self.formula = formula + self.poly_exp = poly_exp + + self._model = None + + @property + def model(self): + return self._model + + @model.setter + def model(self, model): + self._model = model + + def schedule(self, n, **kwargs): + cur_Mimg = getattr(self.model, "sample_counter", 0) / 1e6 + + if self.formula == "constant": + lr = 1.0 + elif self.formula == "poly": + lr = max(cur_Mimg / self.ref_Mimg, 1e-8) ** -self.poly_exp + else: + raise ValueError(f'Invalid learning rate formula "{self.formula}"') + + if self.max_lr_ratio is not None: + lr = min(lr, self.max_lr_ratio) + if self.min_lr_ratio is not None: + lr = max(lr, self.min_lr_ratio) + + if self.rampup_Mimg > 0 and cur_Mimg < self.rampup_Mimg: + lr *= cur_Mimg / self.rampup_Mimg + if self.rampdown_Mimg > 0 and cur_Mimg > self.total_Mimg - self.rampdown_Mimg: + lr *= (self.total_Mimg - cur_Mimg) / self.rampdown_Mimg + + return lr + + def __call__(self, n, **kwargs): + return self.schedule(n, **kwargs) + + +class LambdaWarmUpCosineScheduler: + """ + A learning rate scheduler that combines warm-up with a cosine decay schedule for multiple cycles. + It supports different configurations for each cycle, including the number of warm-up steps, minimum + and maximum scaling factors for the learning rate. + + The scheduler is intended to be used with a base learning rate of 1.0, where the actual learning + rate at any step is the base learning rate multiplied by the scaling factor computed by the scheduler. + + Parameters: + warm_up_steps (list[int]): List of integers where each element represents the number of warm-up + steps for the corresponding cycle. + f_min (list[float]): List of the minimum scaling factors for each cycle after warm-up. + f_max (list[float]): List of the maximum scaling factors at the start and end of each cosine cycle. + f_start (list[float]): List of starting scaling factors for each warm-up phase. + cycle_lengths (list[int]): List of the total lengths of each cycle, including warm-up steps. + verbosity_interval (int, optional): Interval of training steps at which to print current step and + scaling factor information. Set to 0 by default to disable verbosity. + + Examples: + >>> scheduler = LambdaWarmUpCosineScheduler2( + warm_up_steps=[10, 10], + f_min=[0.1, 0.1], + f_max=[1.0, 1.0], + f_start=[0.01, 0.01], + cycle_lengths=[50, 50], + verbosity_interval=10) + >>> for step in range(100): + >>> lr_multiplier = scheduler(step) + >>> print(f"Step {step}: LR Multiplier = {lr_multiplier}") + """ + + def __init__(self, warm_up_steps, f_min, f_max, f_start, cycle_lengths, verbosity_interval=0): + assert len(warm_up_steps) == len(f_min) == len(f_max) == len(f_start) == len(cycle_lengths) + self.lr_warm_up_steps = warm_up_steps + self.f_start = f_start + self.f_min = f_min + self.f_max = f_max + self.cycle_lengths = cycle_lengths + self.cum_cycles = np.cumsum([0] + list(self.cycle_lengths)) + self.last_f = 0.0 + self.verbosity_interval = verbosity_interval + + def find_in_interval(self, n): + interval = 0 + for cl in self.cum_cycles[1:]: + if n <= cl: + return interval + interval += 1 + + def schedule(self, n, **kwargs): + cycle = self.find_in_interval(n) + n = n - self.cum_cycles[cycle] + if self.verbosity_interval > 0: + if n % self.verbosity_interval == 0: + log.info(f"current step: {n}, recent lr-multiplier: {self.last_f}, " f"current cycle {cycle}") + if n < self.lr_warm_up_steps[cycle]: + f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle] + self.last_f = f + return f + else: + t = (n - self.lr_warm_up_steps[cycle]) / (self.cycle_lengths[cycle] - self.lr_warm_up_steps[cycle]) + t = min(t, 1.0) + f = self.f_min[cycle] + 0.5 * (self.f_max[cycle] - self.f_min[cycle]) * (1 + np.cos(t * np.pi)) + self.last_f = f + return f + + def __call__(self, n, **kwargs): + return self.schedule(n, **kwargs) + + +class LambdaLinearScheduler(LambdaWarmUpCosineScheduler): + """ + Linear instead of cosine decay for the main part of the cycle. + """ + + def schedule(self, n, **kwargs): + cycle = self.find_in_interval(n) + n = n - self.cum_cycles[cycle] + if self.verbosity_interval > 0: + if n % self.verbosity_interval == 0: + log.info(f"current step: {n}, recent lr-multiplier: {self.last_f}, " f"current cycle {cycle}") + + if n < self.lr_warm_up_steps[cycle]: + f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle] + self.last_f = f + return f + else: + f = self.f_min[cycle] + (self.f_max[cycle] - self.f_min[cycle]) * (self.cycle_lengths[cycle] - n) / ( + self.cycle_lengths[cycle] - self.lr_warm_up_steps[cycle] + ) + self.last_f = f + return f diff --git a/cosmos_transfer1/diffusion/training/modules/edm_sde.py b/cosmos_transfer1/diffusion/training/modules/edm_sde.py new file mode 100644 index 00000000..3d08a822 --- /dev/null +++ b/cosmos_transfer1/diffusion/training/modules/edm_sde.py @@ -0,0 +1,43 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from statistics import NormalDist + +import numpy as np +import torch + + +class EDMSDE: + def __init__( + self, + p_mean: float = -1.2, + p_std: float = 1.2, + sigma_max: float = 80.0, + sigma_min: float = 0.002, + ): + self.gaussian_dist = NormalDist(mu=p_mean, sigma=p_std) + self.sigma_max = sigma_max + self.sigma_min = sigma_min + + def sample_t(self, batch_size: int) -> torch.Tensor: + cdf_vals = np.random.uniform(size=(batch_size)) + samples_interval_gaussian = [self.gaussian_dist.inv_cdf(cdf_val) for cdf_val in cdf_vals] + + log_sigma = torch.tensor(samples_interval_gaussian, device="cuda") + return torch.exp(log_sigma) + + def marginal_prob(self, x0: torch.Tensor, sigma: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + """This is trivial in the base class, but may be used by derived classes in a more interesting way""" + return x0, sigma diff --git a/cosmos_transfer1/diffusion/training/tensor_parallel.py b/cosmos_transfer1/diffusion/training/tensor_parallel.py new file mode 100644 index 00000000..c756c38e --- /dev/null +++ b/cosmos_transfer1/diffusion/training/tensor_parallel.py @@ -0,0 +1,102 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os + +import torch +import torch.distributed as dist +from torch.autograd import Function + + +class AllGather(Function): + @staticmethod + def forward(ctx, tensor, process_group): + world_size = dist.get_world_size(process_group) + ctx.world_size = world_size + ctx.rank = process_group.rank() + + gathered_tensors = [torch.zeros_like(tensor) for _ in range(world_size)] + dist.all_gather(gathered_tensors, tensor.contiguous(), process_group) + return torch.cat(gathered_tensors, dim=0) + + @staticmethod + def backward(ctx, grad_output): + world_size = ctx.world_size + rank = ctx.rank + + # Split the gradient tensor + grad_chunks = grad_output.chunk(world_size) + + # Select the gradient chunk for the current rank + grad_input = grad_chunks[rank] + return grad_input, None + + +def gather_along_first_dim(tensor, process_group): + return AllGather.apply(tensor, process_group) + + +class Scatter(Function): + @staticmethod + def forward(ctx, tensor, process_group): + world_size = dist.get_world_size(process_group) + ctx.world_size = world_size + ctx.process_group = process_group + rank = process_group.rank() + + # Split the tensor + tensor_chunks = tensor.chunk(world_size) + + # Select the tensor chunk for the current rank + return tensor_chunks[rank] + + @staticmethod + def backward(ctx, grad_output): + world_size = ctx.world_size + process_group = ctx.process_group + + # Gather the gradient tensor + gathered_grads = [torch.zeros_like(grad_output) for _ in range(world_size)] + dist.all_gather(gathered_grads, grad_output.contiguous(), process_group) + return torch.cat(gathered_grads, dim=0), None + + +def scatter_along_first_dim(tensor, process_group): + return Scatter.apply(tensor, process_group) + + +if __name__ == "__main__": + # Torch global setup for distributed training + local_rank = int(os.environ["LOCAL_RANK"]) + rank = int(os.environ["RANK"]) + world_size = int(os.environ["WORLD_SIZE"]) + torch.cuda.set_device(local_rank) + torch.distributed.init_process_group(world_size=world_size, rank=rank) + + # Create a tensor with gradients + x = torch.randn(10, 1, requires_grad=True, device="cuda") + + # Perform all_gather with gradient support + y = gather_along_first_dim(x, dist.group.WORLD) + print(f"{y.shape=}") + y = scatter_along_first_dim(y, dist.group.WORLD) + print(f"{y.shape=}") + + # Use the result in your computation + loss = y.sum() + loss.backward() + + # x.grad now contains the gradients + print(x.grad) diff --git a/cosmos_transfer1/diffusion/training/train.py b/cosmos_transfer1/diffusion/training/train.py new file mode 100644 index 00000000..40b02af2 --- /dev/null +++ b/cosmos_transfer1/diffusion/training/train.py @@ -0,0 +1,130 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import importlib +import os + +import torch.distributed as dist +from loguru import logger as logging +from omegaconf import OmegaConf + +from cosmos_transfer1.diffusion.config.config import Config +from cosmos_transfer1.utils import log, misc +from cosmos_transfer1.utils.lazy_config import instantiate +from cosmos_transfer1.utils.config_helper import get_config_module, override +from cosmos_transfer1.utils.lazy_config.lazy import LazyConfig +from cosmos_transfer1.utils.parallel_state_helper import is_tp_cp_pp_rank0 + + +@misc.timer("instantiate model") +def instantiate_model(config: Config, trainer) -> None: + misc.set_random_seed(seed=config.trainer.seed, by_rank=False) + config.model_obj.config = config.model + if getattr(config.model, "fsdp_enabled", False): + assert config.trainer.distributed_parallelism == "fsdp", "FSDP model is only supported with FSDP trainer" + log.critical("FSDP enabled") + config.model_obj.fsdp_checkpointer = trainer.checkpointer + model = instantiate(config.model_obj) + config.model_obj.fsdp_checkpointer = None + else: + model = instantiate(config.model_obj) + config.model_obj.config = None + misc.set_random_seed(seed=config.trainer.seed, by_rank=True) + return model + + +def destroy_distributed(): + log.info("Destroying distributed environment...") + if dist.is_available() and dist.is_initialized(): + try: + dist.destroy_process_group() + except ValueError as e: + print(f"Error destroying default process group: {e}") + + +@logging.catch(reraise=True) +def launch(config: Config, args: argparse.Namespace) -> None: + # Check that the config is valid + config.validate() + # Freeze the config so developers don't change it during training. + config.freeze() # type: ignore + trainer = config.trainer.type(config) + # # Setup the miscellaneous stuff for reproducibility. + # log_reproducible_setup(config, args) + # Create the model + model = instantiate_model(config, trainer) + model.on_model_init_end() + # Create the dataloaders. + if args.mp0_only_dl: + log.critical( + "Using only tp_cp_pp_rank0 dataloader for faster dataloading! Make sure val dl is mock and mock data has same keys as real data." + ) + raise NotImplementedError( + "mp0_only_dl is not implemented correctly! Please revisit this code and propose a more robust impl that raise error timely! It does not do necessary check before training to confirm it can work with image / video data. Current impl is problematic for image training." + ) + if is_tp_cp_pp_rank0() or not args.mp0_only_dl: + dataloader_train = instantiate(config.dataloader_train) + else: + dataloader_train = instantiate(config.dataloader_val) + dataloader_val = instantiate(config.dataloader_val) + # Start training + trainer.train( + model, + dataloader_train, + dataloader_val, + ) + destroy_distributed() + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Training") + parser.add_argument( + "--config", + default="cosmos_transfer1/diffusion/posttrain/config/config.py", + help="Path to the config file", + ) + parser.add_argument( + "opts", + help=""" +Modify config options at the end of the command. For Yacs configs, use +space-separated "PATH.KEY VALUE" pairs. +For python-based LazyConfig, use "path.key=value". + """.strip(), + default=None, + nargs=argparse.REMAINDER, + ) + parser.add_argument( + "--dryrun", + action="store_true", + help="Do a dry run without training. Useful for debugging the config.", + ) + parser.add_argument( + "--mp0_only_dl", + action="store_true", + help="Use only model parallel rank 0 dataloader for faster dataloading! Make sure mock data has same keys as real data.", + ) + args = parser.parse_args() + config_module = get_config_module(args.config) + config = importlib.import_module(config_module).make_config() + config = override(config, args.opts) + if args.dryrun: + os.makedirs(config.job.path_local, exist_ok=True) + LazyConfig.save_yaml(config, f"{config.job.path_local}/config.yaml") + print(OmegaConf.to_yaml(OmegaConf.load(f"{config.job.path_local}/config.yaml"))) + print(f"{config.job.path_local}/config.yaml") + else: + # Launch the training job. + launch(config, args) diff --git a/cosmos_transfer1/diffusion/training/utils/optim_instantiate.py b/cosmos_transfer1/diffusion/training/utils/optim_instantiate.py new file mode 100644 index 00000000..c12dca27 --- /dev/null +++ b/cosmos_transfer1/diffusion/training/utils/optim_instantiate.py @@ -0,0 +1,83 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import hydra +import torch +from torch import nn + +from cosmos_transfer1.utils import log +from cosmos_transfer1.utils.fused_adam import FusedAdam + + +def get_regular_param_group(net: nn.Module): + """ + seperate the parameters of the network into two groups: decay and no_decay. + based on nano_gpt codebase. + """ + param_dict = {pn: p for pn, p in net.named_parameters()} + param_dict = {pn: p for pn, p in param_dict.items() if p.requires_grad} + + decay_params = [p for n, p in param_dict.items() if p.dim() >= 2] + nodecay_params = [p for n, p in param_dict.items() if p.dim() < 2] + return decay_params, nodecay_params + + +def get_base_optimizer( + model: nn.Module, + lr: float, + weight_decay: float, + optim_type: str = "adamw", + sharding: bool = False, + **kwargs, +) -> torch.optim.Optimizer: + net_decay_param, net_nodecay_param = get_regular_param_group(model) + + num_decay_params = sum(p.numel() for p in net_decay_param) + num_nodecay_params = sum(p.numel() for p in net_nodecay_param) + net_param_total = num_decay_params + num_nodecay_params + log.critical(f"total num parameters : {net_param_total:,}") + + param_group = [ + { + "params": net_decay_param + net_nodecay_param, + "lr": lr, + "weight_decay": weight_decay, + }, + ] + + if optim_type == "adamw": + opt_cls = torch.optim.AdamW + elif optim_type == "fusedadam": + opt_cls = FusedAdam + else: + raise ValueError(f"Unknown optimizer type: {optim_type}") + + return opt_cls(param_group, **kwargs) + + +def get_base_scheduler( + optimizer: torch.optim.Optimizer, + model: nn.Module, + scheduler_config: dict, +): + net_scheduler = hydra.utils.instantiate(scheduler_config) + net_scheduler.model = model + + return torch.optim.lr_scheduler.LambdaLR( + optimizer, + lr_lambda=[ + net_scheduler.schedule, + ], + ) diff --git a/cosmos_transfer1/utils/callback.py b/cosmos_transfer1/utils/callback.py new file mode 100644 index 00000000..cbb2ad8f --- /dev/null +++ b/cosmos_transfer1/utils/callback.py @@ -0,0 +1,457 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import time +import warnings +from typing import TYPE_CHECKING, Any, Callable, Optional, List + +import omegaconf +import torch +import torch.utils.data +import tqdm +from megatron.core import parallel_state +from torch.distributed.fsdp import FullyShardedDataParallel as FSDP + +from cosmos_transfer1.utils import distributed, log +from cosmos_transfer1.utils.lazy_config import instantiate +from cosmos_transfer1.utils.misc import get_local_tensor_if_DTensor + +if TYPE_CHECKING: + from cosmos_transfer1.utils.config import Config + from cosmos_transfer1.utils.model import Model + from cosmos_transfer1.utils.trainer import Trainer + + +class CallBackGroup: + """A class for hosting a collection of callback objects. + + It is used to execute callback functions of multiple callback objects with the same method name. + When callbackgroup.func(args) is executed, internally it loops through the objects in self._callbacks and runs + self._callbacks[0].func(args), self._callbacks[1].func(args), etc. The method name and arguments should match. + + Attributes: + _callbacks (list[Callback]): List of callback objects. + """ + + def __init__(self, config: Config, trainer: Trainer) -> None: + """Initializes the list of callback objects. + + Args: + config (Config): The config object for the codebase. + trainer (Trainer): The main trainer. + """ + self._callbacks = [] + callback_configs = config.trainer.callbacks + if callback_configs: + if isinstance(callback_configs, list) or isinstance(callback_configs, omegaconf.listconfig.ListConfig): + warnings.warn( + "The 'config.trainer.callbacks' parameter should be a dict instead of a list. " + "Please update your code", + DeprecationWarning, + stacklevel=2, + ) + callback_configs = {f"callback_{i}": v for i, v in enumerate(callback_configs)} + for callback_name, current_callback_cfg in callback_configs.items(): + if "_target_" not in current_callback_cfg: + log.critical( + f"Callback {callback_name} is missing the '_target_' field. \n SKip {current_callback_cfg}" + ) + continue + log.critical(f"Instantiating callback {callback_name}: {current_callback_cfg}") + _callback = instantiate(current_callback_cfg) + assert isinstance(_callback, Callback), f"{current_callback_cfg} is not a valid callback." + _callback.config = config + _callback.trainer = trainer + self._callbacks.append(_callback) + + def __getattr__(self, method_name: str) -> Callable: + """Loops through the callback objects to call the corresponding callback function. + + Args: + method_name (str): Callback method name. + """ + + def multi_callback_wrapper(*args, **kwargs) -> None: + for callback in self._callbacks: + assert hasattr(callback, method_name) + method = getattr(callback, method_name) + assert callable(method) + _ = method(*args, **kwargs) + + return multi_callback_wrapper + + +class Callback: + """The base class for all callbacks. + + All callbacks should inherit from this class and adhere to the established method names and signatures. + """ + + def __init__(self, config: Optional["Config"] = None, trainer: Optional["Trainer"] = None): + """Initializes a Callback object. + + Args: + config (Optional[Config]): The configuration object for the codebase, if available. + trainer (Optional[Trainer]): The main trainer handling the training loop, if available. + + Notes: + The config and trainer parameters are optional to maintain backward compatibility. + In future releases, these parameters will be removed. Upon using these parameters, a deprecation + warning will be issued. + + """ + if config is not None or trainer is not None: + warnings.warn( + "The 'config' and 'trainer' parameters are deprecated and will be removed in a future release. " + "Please update your code to create Callback instances without these parameters.", + DeprecationWarning, + stacklevel=2, + ) + del config, trainer + + def on_train_start(self, model: Model, iteration: int = 0) -> None: + pass + + def on_training_step_start(self, model: Model, data: dict[str, torch.Tensor], iteration: int = 0) -> None: + pass + + def on_before_forward(self, iteration: int = 0) -> None: + pass + + def on_after_forward(self, iteration: int = 0) -> None: + pass + + def on_before_backward( + self, model_ddp: distributed.DistributedDataParallel, loss: torch.Tensor, iteration: int = 0 + ) -> None: + pass + + def on_after_backward(self, model_ddp: distributed.DistributedDataParallel, iteration: int = 0) -> None: + pass + + def on_before_dataloading(self, iteration: int = 0) -> None: + pass + + def on_after_dataloading(self, iteration: int = 0) -> None: + pass + + def on_optimizer_init_start(self) -> None: + pass + + def on_optimizer_init_end(self) -> None: + pass + + def on_before_optimizer_step( + self, + model_ddp: distributed.DistributedDataParallel, + optimizer: torch.optim.Optimizer, + scheduler: torch.optim.lr_scheduler.LRScheduler, + grad_scaler: torch.amp.GradScaler, + iteration: int = 0, + ) -> None: + pass + + def on_before_zero_grad( + self, + model_ddp: distributed.DistributedDataParallel, + optimizer: torch.optim.Optimizer, + scheduler: torch.optim.lr_scheduler.LRScheduler, + iteration: int = 0, + ) -> None: + pass + + def on_training_step_end( + self, + model: Model, + data_batch: dict[str, torch.Tensor], + output_batch: dict[str, torch.Tensor], + loss: torch.Tensor, + iteration: int = 0, + ) -> None: + pass + + def on_validation_start( + self, model: Model, dataloader_val: torch.utils.data.DataLoader, iteration: int = 0 + ) -> None: + pass + + def on_validation_step_start(self, model: Model, data: dict[str, torch.Tensor], iteration: int = 0) -> None: + pass + + def on_validation_step_end( + self, + model: Model, + data_batch: dict[str, torch.Tensor], + output_batch: dict[str, torch.Tensor], + loss: torch.Tensor, + iteration: int = 0, + ) -> None: + pass + + def on_validation_end(self, model: Model, iteration: int = 0) -> None: + pass + + def on_load_checkpoint_start(self, model: Model) -> None: + pass + + def on_load_checkpoint_end(self, model: Model) -> None: + pass + + def on_load_checkpoint(self, model: Model, state_dict: dict[Any]) -> None: + pass + + def on_save_checkpoint_start(self, model: Model, iteration: int = 0) -> None: + pass + + def on_save_checkpoint_end(self, model: Model, iteration: int = 0) -> None: + pass + + def on_save_checkpoint_success(self, iteration: int = 0) -> None: + pass + + def on_save_checkpoint(self, model: Model, state_dict: dict[Any]) -> None: + pass + + def on_train_end(self, model: Model, iteration: int = 0) -> None: + pass + + def on_app_end(self) -> None: + pass + + +class EMAModelCallback(Callback): + """The callback class for tracking EMA model weights.""" + + def on_train_start(self, model: Model, iteration: int = 0) -> None: + # Set up the EMA model weight tracker. + if model.config.ema.enabled: + assert hasattr(model, "ema"), "EMA should be initialized from Model" + # EMA model must be kept in FP32 precision. + model.ema = model.ema.to(dtype=torch.float32) + else: + assert not hasattr(model, "ema"), "There should be no EMA initialized." + + def on_training_step_end( + self, + model: Model, + data_batch: dict[str, torch.Tensor], + output_batch: dict[str, torch.Tensor], + loss: torch.Tensor, + iteration: int = 0, + ) -> None: + # Update the EMA model with the new regular weights. + if model.config.ema.enabled: + model.ema.update_average(model, iteration) + + +class ProgressBarCallback(Callback): + """The callback class for visualizing the training/validation progress bar in the console.""" + + @distributed.rank0_only + def on_train_start(self, model: Model, iteration: int = 0) -> None: + self.train_pbar = tqdm.trange(self.config.trainer.max_iter, initial=iteration, desc="Training") + + @distributed.rank0_only + def on_training_step_end( + self, + model: Model, + data_batch: dict[str, torch.Tensor], + output_batch: dict[str, torch.Tensor], + loss: torch.Tensor, + iteration: int = 0, + ) -> None: + self.train_pbar.update() + + @distributed.rank0_only + def on_validation_start( + self, model: Model, dataloader_val: torch.utils.data.DataLoader, iteration: int = 0 + ) -> None: + if self.config.trainer.max_val_iter is not None: + num_iter = self.config.trainer.max_val_iter + else: + num_iter = len(dataloader_val) + assert num_iter is not None and num_iter > 0, f"Invalid number of validation iterations: {num_iter}" + self.val_pbar = tqdm.trange(num_iter, desc="Validating", position=1, leave=False) + + @distributed.rank0_only + def on_validation_step_end( + self, + model: Model, + data_batch: dict[str, torch.Tensor], + output_batch: dict[str, torch.Tensor], + loss: torch.Tensor, + iteration: int = 0, + ) -> None: + self.val_pbar.update() + + @distributed.rank0_only + def on_validation_end(self, model: Model, iteration: int = 0) -> None: + self.val_pbar.close() + + @distributed.rank0_only + def on_train_end(self, model: Model, iteration: int = 0) -> None: + self.trainer.checkpointer.finalize() + self.train_pbar.close() + + +class IterationLoggerCallback(Callback): + """The callback class for visualizing the training/validation progress bar in the console.""" + + @distributed.rank0_only + def on_train_start(self, model: Model, iteration: int = 0) -> None: + # self.train_pbar = tqdm.trange(self.config.trainer.max_iter, initial=iteration, desc="Training") + self.start_iteration_time = time.time() + self.elapsed_iteration_time = 0 + + @distributed.rank0_only + def on_training_step_start(self, model: Model, data: dict[str, torch.Tensor], iteration: int = 0) -> None: + self.start_iteration_time = time.time() + + @distributed.rank0_only + def on_training_step_end( + self, + model: Model, + data_batch: dict[str, torch.Tensor], + output_batch: dict[str, torch.Tensor], + loss: torch.Tensor, + iteration: int = 0, + ) -> None: + self.elapsed_iteration_time += time.time() - self.start_iteration_time + + if iteration % self.config.trainer.logging_iter == 0: + avg_time = self.elapsed_iteration_time / self.config.trainer.logging_iter + log.info(f"Iteration: {iteration}, average iter time: {avg_time:2f}, total loss {loss.item():4f}") + + self.elapsed_iteration_time = 0 + + +@torch.jit.script +def _fused_nan_to_num(params: List[torch.Tensor]): + for param in params: + torch.nan_to_num(param, nan=0.0, posinf=0.0, neginf=0.0, out=param) + + +class GradClip(Callback): + def __init__( + self, clip_norm=1.0, force_finite: bool = True, model_key: Optional[str] = None, fsdp_enabled: bool = False + ): + self.clip_norm = clip_norm + self.force_finite = force_finite + self.model_key = model_key + self.fsdp_enabled = fsdp_enabled + + def on_before_optimizer_step( + self, + model_ddp: distributed.DistributedDataParallel, + optimizer: torch.optim.Optimizer, + scheduler: torch.optim.lr_scheduler.LRScheduler, + grad_scaler: torch.amp.GradScaler, + iteration: int = 0, + ) -> None: + del optimizer, scheduler + if isinstance(model_ddp, distributed.DistributedDataParallel): + model = model_ddp.module + else: + model = model_ddp + + # select sub-network if specified + if self.model_key is not None: + items = self.model_key.split(".") + for item in items: + model = getattr(model, item) + + if self.force_finite: + params = [] + for param in model.parameters(): + if param.grad is not None: + params.append(param.grad) + # torch.nan_to_num(param.grad, nan=0, posinf=0, neginf=0, out=param.grad) + _fused_nan_to_num(params) + + # check if FSDP is used + # total_norm + if isinstance(model, FSDP) and self.fsdp_enabled: + model.clip_grad_norm_(self.clip_norm) + else: + torch.nn.utils.clip_grad_norm_(model.parameters(), self.clip_norm, foreach=True) + + + +class GradClipCallback(Callback): + """The callback class for gradient clipping.""" + + def __init__( + self, + config: Optional["Config"] = None, + trainer: Optional["Trainer"] = None, + grad_clip_norm: float = 1.0, + ): + super().__init__(config, trainer) + self.grad_clip_norm = grad_clip_norm + + def on_before_optimizer_step( + self, + model_ddp: distributed.DistributedDataParallel, + optimizer: torch.optim.Optimizer, + scheduler: torch.optim.lr_scheduler.LRScheduler, + grad_scaler: torch.amp.GradScaler, + iteration: int = 0, + ) -> None: + grad_scaler.unscale_(optimizer) + torch.nn.utils.clip_grad_norm_(model_ddp.module.parameters(), max_norm=self.grad_clip_norm) + + +class LowPrecisionCallback(Callback): + """The callback class handling low precision training""" + + def __init__(self, update_iter: int, config: Optional["Config"] = None, trainer: Optional["Trainer"] = None): + super().__init__(config, trainer) + self.update_iter = update_iter + + def on_train_start(self, model: Model, iteration: int = 0) -> None: + assert model.precision in [ + torch.bfloat16, + torch.float16, + torch.half, + ], "LowPrecisionCallback must use a low precision dtype." + self.precision_type = model.precision + + def on_training_step_start(self, model: Model, data: dict[str, torch.Tensor], iteration: int = 0) -> None: + for k, v in data.items(): + if isinstance(v, torch.Tensor) and torch.is_floating_point(data[k]): + data[k] = v.to(dtype=self.precision_type) + + def on_validation_step_start(self, model: Model, data: dict[str, torch.Tensor], iteration: int = 0) -> None: + for k, v in data.items(): + if isinstance(v, torch.Tensor) and torch.is_floating_point(data[k]): + data[k] = v.to(dtype=self.precision_type) + + def on_before_zero_grad( + self, + model_ddp: distributed.DistributedDataParallel, + optimizer: torch.optim.Optimizer, + scheduler: torch.optim.lr_scheduler.LRScheduler, + iteration: int = 0, + ) -> None: + if iteration % self.update_iter == 0: + if getattr(optimizer, "master_weights", False): + params, master_params = [], [] + for group, group_master in zip(optimizer.param_groups, optimizer.param_groups_master): + for p, p_master in zip(group["params"], group_master["params"]): + params.append(get_local_tensor_if_DTensor(p.data)) + master_params.append(p_master.data) + torch._foreach_copy_(params, master_params) diff --git a/cosmos_transfer1/utils/checkpointer.py b/cosmos_transfer1/utils/checkpointer.py new file mode 100644 index 00000000..2c8617ff --- /dev/null +++ b/cosmos_transfer1/utils/checkpointer.py @@ -0,0 +1,237 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import os +import threading +from typing import TYPE_CHECKING + +import torch + +from cosmos_transfer1.utils import callback, distributed, log, misc +from cosmos_transfer1.utils.model import Model + +if TYPE_CHECKING: + from cosmos_transfer1.utils.config import CheckpointConfig, JobConfig + + +class Checkpointer: + """The checkpointer class. Supports checkpoint saving/loading to local disk.""" + + def __init__(self, config_checkpoint: CheckpointConfig, config_job: JobConfig, callbacks: callback.CallBackGroup): + """Constructor of the checkpointer. + + Args: + config_checkpoint (CheckpointConfig): The config object for the checkpointer. + """ + # Set the callback functions. + self.callbacks = callbacks + self.checkpoint_dir_local = f"{config_job.path_local}/checkpoints" + self.strict_resume = config_checkpoint.strict_resume + self.load_path = config_checkpoint.load_path or None + self.load_training_state = config_checkpoint.load_training_state + self.only_load_scheduler_state = config_checkpoint.only_load_scheduler_state + self.save_thread = None + + def save( + self, + model: Model, + optimizer: torch.optim.Optimizer, + scheduler: torch.optim.lr_scheduler.LRScheduler, + grad_scaler: torch.amp.GradScaler, + iteration: int, + ) -> None: + """Save network weights, optimizer parameters, scheduler parameters to a checkpoint. + + Args: + model (Model): The PyTorch model. + optimizer (torch.optim.Optimizer): The model optimizer. + scheduler (torch.optim.lr_scheduler.LRScheduler): The optimization scheduler. + grad_scaler (torch.amp.GradScaler): The gradient scaler (for mixed precision training). + iteration (int): Current iteration number. + """ + self.callbacks.on_save_checkpoint_start(model, iteration) + + checkpoint_file = f"iter_{iteration:09}.pt" + + if distributed.get_rank() == 0: + state_dict = dict( + model=model.state_dict(), + optimizer=optimizer.state_dict(), + scheduler=scheduler.state_dict(), + grad_scaler=grad_scaler.state_dict(), + iteration=iteration, + ) + state_dict = misc.to(state_dict, device="cpu") + self.callbacks.on_save_checkpoint(model, state_dict=state_dict) + # Wait for previous saver thread to end. + if self.save_thread: + self.save_thread.join() + # Run the checkpoint saver in a separate thread. + self.save_thread = threading.Thread( + target=self._save_worker_local, + daemon=False, + args=(state_dict, checkpoint_file, distributed.get_rank()), + ) + self.save_thread.start() + + # Note: Checkpoints are saved on a separate thread and this callback is not accurate. + # Please check logs from on_save_checkpoint_success() for better accuracy + self.callbacks.on_save_checkpoint_end(model=None, iteration=iteration) + + @misc.timer("checkpoint saving (local)") + def _save_worker_local(self, state_dict: dict[str, torch.Tensor], checkpoint_file: str, rank: int = 0) -> None: + """Worker to save checkpoint to local disk, spawned with a child thread (runs in parallel with the training). + + Args: + state_dict (dict[str, torch.Tensor]): The state dict of the model/optimizer/scheduler. + checkpoint_file (str): The file name of the model checkpoint. + rank (int): GPU device (default: 0). + """ + checkpoint_path = os.path.join(self.checkpoint_dir_local, checkpoint_file) + os.makedirs(self.checkpoint_dir_local, exist_ok=True) + try: + torch.save(state_dict, checkpoint_path) + if rank == 0: + self._write_latest_checkpoint_file(checkpoint_file) + log.success(f"Saved checkpoint (local): {checkpoint_path}") + iteration = int(checkpoint_file.replace("iter_", "").replace(".pt", "")) + self.callbacks.on_save_checkpoint_success(iteration=iteration) + except Exception as e: # noqa: BLE001 + log.exception(f"Checkpoint failed to save (local): {e}") + + @misc.timer("checkpoint loading") + def load( + self, + model: Model, + optimizer: torch.optim.Optimizer | None = None, + scheduler: torch.optim.lr_scheduler.LRScheduler | None = None, + grad_scaler: torch.amp.GradScaler | None = None, + ) -> int: + """Load network weights and optimizer states from a checkpoint in a single process. + + The priority of the checkpoint loading logic is: + 1. Attempt to resume training if possible by looking for latest_checkpoint.txt under the same name. + 2. If no latest checkpoint were found, it loads the model weights specified by config_checkpoint.path. + - This is typically used for inference mode. + - If config_checkpoint.load_optimizer_state is True, then also load the optimizer and scheduler states. + 3. If none of the above, randomly initialize the model parameters and train from scratch. + + Args: + model (Model): The PyTorch model. + optimizer (torch.optim.Optimizer | None): The model optimizer (default: None). + scheduler (torch.optim.lr_scheduler.LRScheduler | None): The optimization scheduler (default: None). + grad_scaler (torch.amp.GradScaler | None): The gradient scaler (for mixed precision training). + + Returns: + iteration (int): the iteration number to start/resume from. + """ + self.callbacks.on_load_checkpoint_start(model) + + latest_checkpoint_file = self._read_latest_checkpoint_file() + if latest_checkpoint_file is not None: + # 1. Resume training from latest_checkpoint.txt under the same name. + checkpoint_dir = self.checkpoint_dir_local + checkpoint_path = os.path.join(checkpoint_dir, latest_checkpoint_file) + resume = True + only_resume_scheduler = True + else: + if self.load_path: + # 2. Load the module weights specified by config_checkpoint.path. + checkpoint_path = self.load_path + resume = self.load_training_state + only_resume_scheduler = self.only_load_scheduler_state + else: + # 3. Randomly initialize the model parameters and train from scratch. + checkpoint_path = None + resume = False + only_resume_scheduler = False + # Load checkpoint. + if checkpoint_path is not None: + self._check_checkpoint_exists(checkpoint_path) + log.info(f"Loading checkpoint (local): {checkpoint_path}") + state_dict = torch.load(checkpoint_path, map_location=lambda storage, loc: storage, weights_only=False) + log.success(f"Complete loading checkpoint (local): {checkpoint_path}") + self.callbacks.on_load_checkpoint(model, state_dict=state_dict) + # Load the state dicts. + log.info("- Loading the model...") + if "model" in state_dict: + model.load_state_dict(state_dict["model"], strict=self.strict_resume) + else: + model.load_state_dict(state_dict, strict=self.strict_resume) + if resume or only_resume_scheduler: + iteration = state_dict["iteration"] + assert scheduler + log.info("- Loading the scheduler...") + scheduler.load_state_dict(state_dict["scheduler"]) + scheduler.last_epoch = iteration + else: + iteration = 0 + if resume: + assert optimizer + log.info("- Loading the optimizer...") + optimizer.load_state_dict(state_dict["optimizer"]) + log.info("- Loading the gradient scaler...") + grad_scaler.load_state_dict(state_dict["grad_scaler"]) + log.success(f"Done with loading the checkpoint (iteration {iteration}).") + else: + log.success("Done with loading the checkpoint.") + else: + # Checkpoint not found and not specified. We will train everything from scratch. + iteration = 0 + log.info("Training from scratch.") + torch.cuda.empty_cache() + + self.callbacks.on_load_checkpoint_end(model) + + return iteration + + def _read_latest_checkpoint_file(self) -> str | None: + """Get the file name of the latest saved checkpoint. If it doesn't exist, return None. + + Returns: + checkpoint_file (str | None): file name of the latest saved checkpoint. + """ + checkpoint_file = None + latest_path = os.path.join(self.checkpoint_dir_local, "latest_checkpoint.txt") + if os.path.isfile(latest_path): + checkpoint_file = open(latest_path).read().strip() + return checkpoint_file + + def _write_latest_checkpoint_file(self, checkpoint_file: str) -> None: + """Track the file name of the latest saved checkpoint. + + Args: + checkpoint_file (str): file name of the latest saved checkpoint. + """ + content = f"{checkpoint_file}\n" + latest_path = os.path.join(self.checkpoint_dir_local, "latest_checkpoint.txt") + with open(latest_path, "w") as file: + file.write(content) + + def _check_checkpoint_exists(self, checkpoint_path: str) -> None: + """If the file checkpoint_path does not exist, raise an error. + + Args: + checkpoint_path (str): full path to the checkpoint. + """ + if not os.path.exists(checkpoint_path): + raise FileNotFoundError(f"File not found (local): {checkpoint_path}") + + def finalize(self) -> None: + """Finalize the checkpointer.""" + if self.save_thread: + self.save_thread.join() diff --git a/cosmos_transfer1/utils/config.py b/cosmos_transfer1/utils/config.py index 88d8d1b4..e91b9bf7 100644 --- a/cosmos_transfer1/utils/config.py +++ b/cosmos_transfer1/utils/config.py @@ -15,16 +15,27 @@ from __future__ import annotations -from typing import Any, TypeVar +import os +from typing import Any, Dict, Optional, Type, TypeVar, Union import attrs +import torch +try: + from megatron.core import ModelParallelConfig + + USE_MEGATRON = True +except ImportError: + USE_MEGATRON = False + print("Megatron-core is not installed.") + +from cosmos_transfer1.utils.lazy_config import LazyCall as L from cosmos_transfer1.utils.lazy_config import LazyDict from cosmos_transfer1.utils.misc import Color +from cosmos_transfer1.utils.callback import EMAModelCallback, ProgressBarCallback T = TypeVar("T") - def _is_attrs_instance(obj: object) -> bool: """ Helper function to check if an object is an instance of an attrs-defined class. @@ -140,6 +151,129 @@ class JobConfig: def path(self) -> str: return f"{self.project}/{self.group}/{self.name}" + @property + def path_local(self) -> str: + local_root = os.environ.get("OUTPUT_ROOT", "checkpoints") + return f"{local_root}/{self.path}" + + +@make_freezable +@attrs.define(slots=False) +class EMAConfig: + # Enable tracking a set of exponential moving average (EMA) weights. + enabled: bool = False + # EMA decay rate. + beta: float = 0.9999 + # Enable removing "_orig_mod-" from buffer names that is added by torch.compile + torch_compile_buffer_renaming: bool = False + + +@make_freezable +@attrs.define(slots=False) +class DDPConfig: + # Traverse the computation graph to find parameters that don't receive gradients. + find_unused_parameters: bool = False + # Set to True if the computation graph does not change during the whole training loop. + static_graph: bool = True + # Set to True if we want to synchronize buffers. Set to False if the sync is going to be handled elsewhere. + broadcast_buffers: bool = True + + +@make_freezable +@attrs.define(slots=False) +class CuDNNConfig: + # Set to True for better reproducibility of the results (only using deterministic cudnn functions). + deterministic: bool = False + # If set to True, cudnn will benchmark several algorithms and pick the fastest one. + benchmark: bool = True + + +@make_freezable +@attrs.define(slots=False) +class JITConfig: + # Enable exporting a JIT compiled model. + enabled: bool = False + # Input tensor shape, for example input. + input_shape: Union[list[int], None] = None + # Device to compile onto. + device: str = "cuda" + # # Data type to compile onto. + dtype: str = "bfloat16" + # Strict mode for PyTorch JIT. + strict: bool = True + + +@make_freezable +@attrs.define(slots=False) +class CheckpointConfig: + # possible checkpoint class + type: Optional[Dict] = None + # for dcp, whether to use async mode + dcp_async_mode_enabled: bool = False + # Save the checkpoint every N iterations. + save_iter: int = 999999999 + # Path of model weights to resume the checkpoint from. + load_path: str = "" + # Whether to load the training states (optimizer/scheduler/grad-scaler) from the checkpoint path. + load_training_state: bool = False + # Whether to load the scheduler state only from the checkpoint path. If load_training_state is True, this will be ignored. + only_load_scheduler_state: bool = False + # Load state_dict to the models in strict mode. + strict_resume: bool = True + # Print detailed information during checkpoint saving/loading. + verbose: bool = True + # Configs for JIT compiling EMA model. + jit: JITConfig = attrs.field(factory=JITConfig) + # keys not to resume from the checkpoint, choices: ["model", "optim", "scheduler", "trainer"] + keys_not_to_resume: list[str] = [] + # Whether to use the local filesystem for broadcasting checkpoint data (used for Tensor Parallel Checkpointer). + broadcast_via_filesystem: bool = False + load_ema_to_reg: bool = False + + + +@make_freezable +@attrs.define(slots=False) +class TrainerConfig: + from cosmos_transfer1.utils.trainer import Trainer + + type: Type[Trainer] = Trainer + # Set the callback class. + # Defaults to the callbacks below. + callbacks: LazyDict = LazyDict( + dict( + ema=L(EMAModelCallback)(), + progress_bar=L(ProgressBarCallback)(), + ) + ) + # distributed parallelism strategy + distributed_parallelism: str = "ddp" + # Distributed data parallel configs. + ddp: DDPConfig = attrs.field(factory=DDPConfig) + # cuDNN configs. + cudnn: CuDNNConfig = attrs.field(factory=CuDNNConfig) + # Set the random seed. + seed: int = 0 + # Gradient scaler arguments (for torch.amp.GradScaler). + grad_scaler_args: dict = attrs.field(factory=lambda: dict(enabled=False)) + # Maximum number of iterations to train the model. + max_iter: int = 999999999 + # Maximum number of iterations to validate the model. If None, validate on the entire dataset. + max_val_iter: int | None = None + # How often we log the training stats. + logging_iter: int = 100 + # Whether we want to run the validation routines. + run_validation: bool = True + # How often we evaluate on the validation set. + validation_iter: int = 999999999 + # Kill the process after N seconds since the last iteration (usually means dead job). + timeout_period: int = 999999999 + # Tensor memory organization format. + memory_format: torch.memory_format = torch.preserve_format + # Gradient accumulation (update step every N iteration). + grad_accum_iter: int = 1 + # # Profiling config + # profiling: Profiling = attrs.field(factory=Profiling) @make_freezable @attrs.define(slots=False) @@ -151,10 +285,35 @@ class Config: # Model configs. model: LazyDict + # Optimizer configs. + optimizer: LazyDict = LazyDict(dict(dummy=None)) + # Scheduler configs. + scheduler: LazyDict = LazyDict(dict(dummy=None)) + # Training data configs. + dataloader_train: LazyDict = LazyDict(dict(dummy=None)) + # Validation data configs. + dataloader_val: LazyDict = LazyDict(dict(dummy=None)) + # Training job configs. job: JobConfig = attrs.field(factory=JobConfig) + # Trainer configs. + trainer: TrainerConfig = attrs.field(factory=TrainerConfig) + + # Megatron-Core configs + if USE_MEGATRON: + # Megatron-Core configs + model_parallel: ModelParallelConfig = attrs.field(factory=ModelParallelConfig) + else: + model_parallel: None = None + + # Checkpointer configs. + checkpoint: CheckpointConfig = attrs.field(factory=CheckpointConfig) + + def pretty_print(self, use_color: bool = False) -> str: + return _pretty_print_attrs_instance(self, 0, use_color) + def to_dict(self) -> dict[str, Any]: return attrs.asdict(self) diff --git a/cosmos_transfer1/utils/easy_io/__init__.py b/cosmos_transfer1/utils/easy_io/__init__.py new file mode 100644 index 00000000..3159bfe6 --- /dev/null +++ b/cosmos_transfer1/utils/easy_io/__init__.py @@ -0,0 +1,14 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/cosmos_transfer1/utils/easy_io/backends/__init__.py b/cosmos_transfer1/utils/easy_io/backends/__init__.py new file mode 100644 index 00000000..86481ffa --- /dev/null +++ b/cosmos_transfer1/utils/easy_io/backends/__init__.py @@ -0,0 +1,13 @@ +from cosmos_transfer1.utils.easy_io.backends.base_backend import BaseStorageBackend +from cosmos_transfer1.utils.easy_io.backends.http_backend import HTTPBackend +from cosmos_transfer1.utils.easy_io.backends.local_backend import LocalBackend +from cosmos_transfer1.utils.easy_io.backends.registry_utils import backends, prefix_to_backends, register_backend + +__all__ = [ + "BaseStorageBackend", + "LocalBackend", + "HTTPBackend", + "register_backend", + "backends", + "prefix_to_backends", +] diff --git a/cosmos_transfer1/utils/easy_io/backends/base_backend.py b/cosmos_transfer1/utils/easy_io/backends/base_backend.py new file mode 100644 index 00000000..2db3b921 --- /dev/null +++ b/cosmos_transfer1/utils/easy_io/backends/base_backend.py @@ -0,0 +1,60 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import os.path as osp +from abc import ABCMeta, abstractmethod + + +def mkdir_or_exist(dir_name, mode=0o777): + if dir_name == "": + return + dir_name = osp.expanduser(dir_name) + os.makedirs(dir_name, mode=mode, exist_ok=True) + + +def has_method(obj, method): + return hasattr(obj, method) and callable(getattr(obj, method)) + + +class BaseStorageBackend(metaclass=ABCMeta): + """Abstract class of storage backends. + + All backends need to implement two apis: :meth:`get()` and + :meth:`get_text()`. + + - :meth:`get()` reads the file as a byte stream. + - :meth:`get_text()` reads the file as texts. + """ + + # a flag to indicate whether the backend can create a symlink for a file + # This attribute will be deprecated in future. + _allow_symlink = False + + @property + def allow_symlink(self): + return self._allow_symlink + + @property + def name(self): + return self.__class__.__name__ + + @abstractmethod + def get(self, filepath): + pass + + @abstractmethod + def get_text(self, filepath): + pass diff --git a/cosmos_transfer1/utils/easy_io/backends/http_backend.py b/cosmos_transfer1/utils/easy_io/backends/http_backend.py new file mode 100644 index 00000000..8ed64251 --- /dev/null +++ b/cosmos_transfer1/utils/easy_io/backends/http_backend.py @@ -0,0 +1,91 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import tempfile +from contextlib import contextmanager +from pathlib import Path +from typing import Generator, Union +from urllib.request import urlopen + +from cosmos_transfer1.utils.easy_io.backends.base_backend import BaseStorageBackend + + +class HTTPBackend(BaseStorageBackend): + """HTTP and HTTPS storage bachend.""" + + def get(self, filepath: str) -> bytes: + """Read bytes from a given ``filepath``. + + Args: + filepath (str): Path to read data. + + Returns: + bytes: Expected bytes object. + + Examples: + >>> backend = HTTPBackend() + >>> backend.get('http://path/of/file') + b'hello world' + """ + return urlopen(filepath).read() + + def get_text(self, filepath, encoding="utf-8") -> str: + """Read text from a given ``filepath``. + + Args: + filepath (str): Path to read data. + encoding (str): The encoding format used to open the ``filepath``. + Defaults to 'utf-8'. + + Returns: + str: Expected text reading from ``filepath``. + + Examples: + >>> backend = HTTPBackend() + >>> backend.get_text('http://path/of/file') + 'hello world' + """ + return urlopen(filepath).read().decode(encoding) + + @contextmanager + def get_local_path(self, filepath: str) -> Generator[Union[str, Path], None, None]: + """Download a file from ``filepath`` to a local temporary directory, + and return the temporary path. + + ``get_local_path`` is decorated by :meth:`contxtlib.contextmanager`. It + can be called with ``with`` statement, and when exists from the + ``with`` statement, the temporary path will be released. + + Args: + filepath (str): Download a file from ``filepath``. + + Yields: + Iterable[str]: Only yield one temporary path. + + Examples: + >>> backend = HTTPBackend() + >>> # After existing from the ``with`` clause, + >>> # the path will be removed + >>> with backend.get_local_path('http://path/of/file') as path: + ... # do something here + """ + try: + f = tempfile.NamedTemporaryFile(delete=False) + f.write(self.get(filepath)) + f.close() + yield f.name + finally: + os.remove(f.name) diff --git a/cosmos_transfer1/utils/easy_io/backends/local_backend.py b/cosmos_transfer1/utils/easy_io/backends/local_backend.py new file mode 100644 index 00000000..a99247f9 --- /dev/null +++ b/cosmos_transfer1/utils/easy_io/backends/local_backend.py @@ -0,0 +1,550 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import io +import os +import os.path as osp +import shutil +from contextlib import contextmanager +from pathlib import Path +from typing import Generator, Iterator, Optional, Tuple, Union + +from cosmos_transfer1.utils.easy_io.backends.base_backend import BaseStorageBackend, mkdir_or_exist + + +class LocalBackend(BaseStorageBackend): + """Raw local storage backend.""" + + _allow_symlink = True + + def get(self, filepath: Union[str, Path]) -> bytes: + """Read bytes from a given ``filepath`` with 'rb' mode. + + Args: + filepath (str or Path): Path to read data. + + Returns: + bytes: Expected bytes object. + + Examples: + >>> backend = LocalBackend() + >>> filepath = '/path/of/file' + >>> backend.get(filepath) + b'hello world' + """ + with open(filepath, "rb") as f: + value = f.read() + return value + + def get_text(self, filepath: Union[str, Path], encoding: str = "utf-8") -> str: + """Read text from a given ``filepath`` with 'r' mode. + + Args: + filepath (str or Path): Path to read data. + encoding (str): The encoding format used to open the ``filepath``. + Defaults to 'utf-8'. + + Returns: + str: Expected text reading from ``filepath``. + + Examples: + >>> backend = LocalBackend() + >>> filepath = '/path/of/file' + >>> backend.get_text(filepath) + 'hello world' + """ + with open(filepath, encoding=encoding) as f: + text = f.read() + return text + + def put(self, obj: Union[bytes, io.BytesIO], filepath: Union[str, Path]) -> None: + """Write bytes to a given ``filepath`` with 'wb' mode. + + Note: + ``put`` will create a directory if the directory of + ``filepath`` does not exist. + + Args: + obj (bytes): Data to be written. + filepath (str or Path): Path to write data. + + Examples: + >>> backend = LocalBackend() + >>> filepath = '/path/of/file' + >>> backend.put(b'hello world', filepath) + """ + mkdir_or_exist(osp.dirname(filepath)) + if isinstance(obj, io.BytesIO): + obj.seek(0) + obj = obj.getvalue() + with open(filepath, "wb") as f: + f.write(obj) + + def put_text(self, obj: str, filepath: Union[str, Path], encoding: str = "utf-8") -> None: + """Write text to a given ``filepath`` with 'w' mode. + + Note: + ``put_text`` will create a directory if the directory of + ``filepath`` does not exist. + + Args: + obj (str): Data to be written. + filepath (str or Path): Path to write data. + encoding (str): The encoding format used to open the ``filepath``. + Defaults to 'utf-8'. + + Examples: + >>> backend = LocalBackend() + >>> filepath = '/path/of/file' + >>> backend.put_text('hello world', filepath) + """ + mkdir_or_exist(osp.dirname(filepath)) + with open(filepath, "w", encoding=encoding) as f: + f.write(obj) + + def exists(self, filepath: Union[str, Path]) -> bool: + """Check whether a file path exists. + + Args: + filepath (str or Path): Path to be checked whether exists. + + Returns: + bool: Return ``True`` if ``filepath`` exists, ``False`` otherwise. + + Examples: + >>> backend = LocalBackend() + >>> filepath = '/path/of/file' + >>> backend.exists(filepath) + True + """ + return osp.exists(filepath) + + def isdir(self, filepath: Union[str, Path]) -> bool: + """Check whether a file path is a directory. + + Args: + filepath (str or Path): Path to be checked whether it is a + directory. + + Returns: + bool: Return ``True`` if ``filepath`` points to a directory, + ``False`` otherwise. + + Examples: + >>> backend = LocalBackend() + >>> filepath = '/path/of/dir' + >>> backend.isdir(filepath) + True + """ + return osp.isdir(filepath) + + def isfile(self, filepath: Union[str, Path]) -> bool: + """Check whether a file path is a file. + + Args: + filepath (str or Path): Path to be checked whether it is a file. + + Returns: + bool: Return ``True`` if ``filepath`` points to a file, ``False`` + otherwise. + + Examples: + >>> backend = LocalBackend() + >>> filepath = '/path/of/file' + >>> backend.isfile(filepath) + True + """ + return osp.isfile(filepath) + + def join_path(self, filepath: Union[str, Path], *filepaths: Union[str, Path]) -> str: + r"""Concatenate all file paths. + + Join one or more filepath components intelligently. The return value + is the concatenation of filepath and any members of \*filepaths. + + Args: + filepath (str or Path): Path to be concatenated. + + Returns: + str: The result of concatenation. + + Examples: + >>> backend = LocalBackend() + >>> filepath1 = '/path/of/dir1' + >>> filepath2 = 'dir2' + >>> filepath3 = 'path/of/file' + >>> backend.join_path(filepath1, filepath2, filepath3) + '/path/of/dir/dir2/path/of/file' + """ + return osp.join(filepath, *filepaths) + + @contextmanager + def get_local_path( + self, + filepath: Union[str, Path], + ) -> Generator[Union[str, Path], None, None]: + """Only for unified API and does nothing. + + Args: + filepath (str or Path): Path to be read data. + backend_args (dict, optional): Arguments to instantiate the + corresponding backend. Defaults to None. + + Examples: + >>> backend = LocalBackend() + >>> with backend.get_local_path('abc/def.jpg') as path: + ... # do something here + """ + yield filepath + + def copyfile( + self, + src: Union[str, Path], + dst: Union[str, Path], + ) -> str: + """Copy a file src to dst and return the destination file. + + src and dst should have the same prefix. If dst specifies a directory, + the file will be copied into dst using the base filename from src. If + dst specifies a file that already exists, it will be replaced. + + Args: + src (str or Path): A file to be copied. + dst (str or Path): Copy file to dst. + + Returns: + str: The destination file. + + Raises: + SameFileError: If src and dst are the same file, a SameFileError + will be raised. + + Examples: + >>> backend = LocalBackend() + >>> # dst is a file + >>> src = '/path/of/file' + >>> dst = '/path1/of/file1' + >>> # src will be copied to '/path1/of/file1' + >>> backend.copyfile(src, dst) + '/path1/of/file1' + + >>> # dst is a directory + >>> dst = '/path1/of/dir' + >>> # src will be copied to '/path1/of/dir/file' + >>> backend.copyfile(src, dst) + '/path1/of/dir/file' + """ + return shutil.copy(src, dst) + + def copytree( + self, + src: Union[str, Path], + dst: Union[str, Path], + ) -> str: + """Recursively copy an entire directory tree rooted at src to a + directory named dst and return the destination directory. + + src and dst should have the same prefix and dst must not already exist. + + + Args: + src (str or Path): A directory to be copied. + dst (str or Path): Copy directory to dst. + + Returns: + str: The destination directory. + + Raises: + FileExistsError: If dst had already existed, a FileExistsError will + be raised. + + Examples: + >>> backend = LocalBackend() + >>> src = '/path/of/dir1' + >>> dst = '/path/of/dir2' + >>> backend.copytree(src, dst) + '/path/of/dir2' + """ + return shutil.copytree(src, dst) + + def copyfile_from_local( + self, + src: Union[str, Path], + dst: Union[str, Path], + ) -> str: + """Copy a local file src to dst and return the destination file. Same + as :meth:`copyfile`. + + Args: + src (str or Path): A local file to be copied. + dst (str or Path): Copy file to dst. + + Returns: + str: If dst specifies a directory, the file will be copied into dst + using the base filename from src. + + Raises: + SameFileError: If src and dst are the same file, a SameFileError + will be raised. + + Examples: + >>> backend = LocalBackend() + >>> # dst is a file + >>> src = '/path/of/file' + >>> dst = '/path1/of/file1' + >>> # src will be copied to '/path1/of/file1' + >>> backend.copyfile_from_local(src, dst) + '/path1/of/file1' + + >>> # dst is a directory + >>> dst = '/path1/of/dir' + >>> # src will be copied to + >>> backend.copyfile_from_local(src, dst) + '/path1/of/dir/file' + """ + return self.copyfile(src, dst) + + def copytree_from_local( + self, + src: Union[str, Path], + dst: Union[str, Path], + ) -> str: + """Recursively copy an entire directory tree rooted at src to a + directory named dst and return the destination directory. Same as + :meth:`copytree`. + + Args: + src (str or Path): A local directory to be copied. + dst (str or Path): Copy directory to dst. + + Returns: + str: The destination directory. + + Examples: + >>> backend = LocalBackend() + >>> src = '/path/of/dir1' + >>> dst = '/path/of/dir2' + >>> backend.copytree_from_local(src, dst) + '/path/of/dir2' + """ + return self.copytree(src, dst) + + def copyfile_to_local( + self, + src: Union[str, Path], + dst: Union[str, Path], + dst_type: Optional[str] = None, + ) -> str: + """Copy the file src to local dst and return the destination file. Same + as :meth:`copyfile`. + + If dst specifies a directory, the file will be copied into dst using + the base filename from src. If dst specifies a file that already + exists, it will be replaced. + + Args: + src (str or Path): A file to be copied. + dst (str or Path): Copy file to to local dst. + + Returns: + str: If dst specifies a directory, the file will be copied into dst + using the base filename from src. + + Examples: + >>> backend = LocalBackend() + >>> # dst is a file + >>> src = '/path/of/file' + >>> dst = '/path1/of/file1' + >>> # src will be copied to '/path1/of/file1' + >>> backend.copyfile_to_local(src, dst) + '/path1/of/file1' + + >>> # dst is a directory + >>> dst = '/path1/of/dir' + >>> # src will be copied to + >>> backend.copyfile_to_local(src, dst) + '/path1/of/dir/file' + """ + return self.copyfile(src, dst) + + def copytree_to_local( + self, + src: Union[str, Path], + dst: Union[str, Path], + ) -> str: + """Recursively copy an entire directory tree rooted at src to a local + directory named dst and return the destination directory. + + Args: + src (str or Path): A directory to be copied. + dst (str or Path): Copy directory to local dst. + backend_args (dict, optional): Arguments to instantiate the + prefix of uri corresponding backend. Defaults to None. + + Returns: + str: The destination directory. + + Examples: + >>> backend = LocalBackend() + >>> src = '/path/of/dir1' + >>> dst = '/path/of/dir2' + >>> backend.copytree_from_local(src, dst) + '/path/of/dir2' + """ + return self.copytree(src, dst) + + def remove(self, filepath: Union[str, Path]) -> None: + """Remove a file. + + Args: + filepath (str or Path): Path to be removed. + + Raises: + IsADirectoryError: If filepath is a directory, an IsADirectoryError + will be raised. + FileNotFoundError: If filepath does not exist, an FileNotFoundError + will be raised. + + Examples: + >>> backend = LocalBackend() + >>> filepath = '/path/of/file' + >>> backend.remove(filepath) + """ + if not self.exists(filepath): + raise FileNotFoundError(f"filepath {filepath} does not exist") + + if self.isdir(filepath): + raise IsADirectoryError("filepath should be a file") + + os.remove(filepath) + + def rmtree(self, dir_path: Union[str, Path]) -> None: + """Recursively delete a directory tree. + + Args: + dir_path (str or Path): A directory to be removed. + + Examples: + >>> dir_path = '/path/of/dir' + >>> backend.rmtree(dir_path) + """ + shutil.rmtree(dir_path) + + def copy_if_symlink_fails( + self, + src: Union[str, Path], + dst: Union[str, Path], + ) -> bool: + """Create a symbolic link pointing to src named dst. + + If failed to create a symbolic link pointing to src, directly copy src + to dst instead. + + Args: + src (str or Path): Create a symbolic link pointing to src. + dst (str or Path): Create a symbolic link named dst. + + Returns: + bool: Return True if successfully create a symbolic link pointing + to src. Otherwise, return False. + + Examples: + >>> backend = LocalBackend() + >>> src = '/path/of/file' + >>> dst = '/path1/of/file1' + >>> backend.copy_if_symlink_fails(src, dst) + True + >>> src = '/path/of/dir' + >>> dst = '/path1/of/dir1' + >>> backend.copy_if_symlink_fails(src, dst) + True + """ + try: + os.symlink(src, dst) + return True + except Exception: + if self.isfile(src): + self.copyfile(src, dst) + else: + self.copytree(src, dst) + return False + + def list_dir_or_file( + self, + dir_path: Union[str, Path], + list_dir: bool = True, + list_file: bool = True, + suffix: Optional[Union[str, Tuple[str]]] = None, + recursive: bool = False, + ) -> Iterator[str]: + """Scan a directory to find the interested directories or files in + arbitrary order. + + Note: + :meth:`list_dir_or_file` returns the path relative to ``dir_path``. + + Args: + dir_path (str or Path): Path of the directory. + list_dir (bool): List the directories. Defaults to True. + list_file (bool): List the path of files. Defaults to True. + suffix (str or tuple[str], optional): File suffix that we are + interested in. Defaults to None. + recursive (bool): If set to True, recursively scan the directory. + Defaults to False. + + Yields: + Iterable[str]: A relative path to ``dir_path``. + + Examples: + >>> backend = LocalBackend() + >>> dir_path = '/path/of/dir' + >>> # list those files and directories in current directory + >>> for file_path in backend.list_dir_or_file(dir_path): + ... print(file_path) + >>> # only list files + >>> for file_path in backend.list_dir_or_file(dir_path, list_dir=False): + ... print(file_path) + >>> # only list directories + >>> for file_path in backend.list_dir_or_file(dir_path, list_file=False): + ... print(file_path) + >>> # only list files ending with specified suffixes + >>> for file_path in backend.list_dir_or_file(dir_path, suffix='.txt'): + ... print(file_path) + >>> # list all files and directory recursively + >>> for file_path in backend.list_dir_or_file(dir_path, recursive=True): + ... print(file_path) + """ # noqa: E501 + if list_dir and suffix is not None: + raise TypeError("`suffix` should be None when `list_dir` is True") + + if (suffix is not None) and not isinstance(suffix, (str, tuple)): + raise TypeError("`suffix` must be a string or tuple of strings") + + root = dir_path + + def _list_dir_or_file(dir_path, list_dir, list_file, suffix, recursive): + for entry in os.scandir(dir_path): + if not entry.name.startswith(".") and entry.is_file(): + rel_path = osp.relpath(entry.path, root) + if (suffix is None or rel_path.endswith(suffix)) and list_file: + yield rel_path + elif osp.isdir(entry.path): + if list_dir: + rel_dir = osp.relpath(entry.path, root) + yield rel_dir + if recursive: + yield from _list_dir_or_file(entry.path, list_dir, list_file, suffix, recursive) + + return _list_dir_or_file(dir_path, list_dir, list_file, suffix, recursive) diff --git a/cosmos_transfer1/utils/easy_io/backends/registry_utils.py b/cosmos_transfer1/utils/easy_io/backends/registry_utils.py new file mode 100644 index 00000000..acd77b13 --- /dev/null +++ b/cosmos_transfer1/utils/easy_io/backends/registry_utils.py @@ -0,0 +1,127 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Optional, Type, Union + +from cosmos_transfer1.utils.easy_io.backends.base_backend import BaseStorageBackend +from cosmos_transfer1.utils.easy_io.backends.http_backend import HTTPBackend +from cosmos_transfer1.utils.easy_io.backends.local_backend import LocalBackend + +backends: dict = {} +prefix_to_backends: dict = {} + + +def _register_backend( + name: str, + backend: Type[BaseStorageBackend], + force: bool = False, + prefixes: Union[str, list, tuple, None] = None, +): + """Register a backend. + + Args: + name (str): The name of the registered backend. + backend (BaseStorageBackend): The backend class to be registered, + which must be a subclass of :class:`BaseStorageBackend`. + force (bool): Whether to override the backend if the name has already + been registered. Defaults to False. + prefixes (str or list[str] or tuple[str], optional): The prefix + of the registered storage backend. Defaults to None. + """ + global backends, prefix_to_backends + + if not isinstance(name, str): + raise TypeError("the backend name should be a string, " f"but got {type(name)}") + + if not inspect.isclass(backend): + raise TypeError(f"backend should be a class, but got {type(backend)}") + if not issubclass(backend, BaseStorageBackend): + raise TypeError(f"backend {backend} is not a subclass of BaseStorageBackend") + + if name in backends and not force: + raise ValueError( + f"{name} is already registered as a storage backend, " 'add "force=True" if you want to override it' + ) + backends[name] = backend + + if prefixes is not None: + if isinstance(prefixes, str): + prefixes = [prefixes] + else: + assert isinstance(prefixes, (list, tuple)) + + for prefix in prefixes: + if prefix in prefix_to_backends and not force: + raise ValueError( + f"{prefix} is already registered as a storage backend," + ' add "force=True" if you want to override it' + ) + + prefix_to_backends[prefix] = backend + + +def register_backend( + name: str, + backend: Optional[Type[BaseStorageBackend]] = None, + force: bool = False, + prefixes: Union[str, list, tuple, None] = None, +): + """Register a backend. + + Args: + name (str): The name of the registered backend. + backend (class, optional): The backend class to be registered, + which must be a subclass of :class:`BaseStorageBackend`. + When this method is used as a decorator, backend is None. + Defaults to None. + force (bool): Whether to override the backend if the name has already + been registered. Defaults to False. + prefixes (str or list[str] or tuple[str], optional): The prefix + of the registered storage backend. Defaults to None. + + This method can be used as a normal method or a decorator. + + Examples: + + >>> class NewBackend(BaseStorageBackend): + ... def get(self, filepath): + ... return filepath + ... + ... def get_text(self, filepath): + ... return filepath + >>> register_backend('new', NewBackend) + + >>> @register_backend('new') + ... class NewBackend(BaseStorageBackend): + ... def get(self, filepath): + ... return filepath + ... + ... def get_text(self, filepath): + ... return filepath + """ + if backend is not None: + _register_backend(name, backend, force=force, prefixes=prefixes) + return + + def _register(backend_cls): + _register_backend(name, backend_cls, force=force, prefixes=prefixes) + return backend_cls + + return _register + + +register_backend("local", LocalBackend, prefixes="") +register_backend("http", HTTPBackend, prefixes=["http", "https"]) diff --git a/cosmos_transfer1/utils/easy_io/easy_io.py b/cosmos_transfer1/utils/easy_io/easy_io.py new file mode 100644 index 00000000..cb2959aa --- /dev/null +++ b/cosmos_transfer1/utils/easy_io/easy_io.py @@ -0,0 +1,1066 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +import warnings +from contextlib import contextmanager +from io import BytesIO, StringIO +from pathlib import Path +from typing import IO, Any, Generator, Iterator, Optional, Tuple, Union + +from cosmos_transfer1.utils.easy_io.backends import backends, prefix_to_backends +from cosmos_transfer1.utils.easy_io.file_client import FileClient +from cosmos_transfer1.utils.easy_io.handlers import file_handlers + +backend_instances: dict = {} + + +def is_filepath(filepath): + return isinstance(filepath, (str, Path)) + + +def _parse_uri_prefix(uri: Union[str, Path]) -> str: + """Parse the prefix of uri. + + Args: + uri (str or Path): Uri to be parsed that contains the file prefix. + + Examples: + >>> _parse_uri_prefix('/home/path/of/your/file') + '' + >>> _parse_uri_prefix('s3://path/of/your/file') + 's3' + >>> _parse_uri_prefix('clusterName:s3://path/of/your/file') + 's3' + + Returns: + str: Return the prefix of uri if the uri contains '://'. Otherwise, + return ''. + """ + assert is_filepath(uri) + uri = str(uri) + # if uri does not contains '://', the uri will be handled by + # LocalBackend by default + if "://" not in uri: + return "" + else: + prefix, _ = uri.split("://") + if ":" in prefix: + _, prefix = prefix.split(":") + return prefix + + +def _get_file_backend(prefix: str, backend_args: dict): + """Return a file backend based on the prefix or backend_args. + + Args: + prefix (str): Prefix of uri. + backend_args (dict): Arguments to instantiate the corresponding + backend. + """ + # backend name has a higher priority + if "backend" in backend_args: + # backend_args should not be modified + backend_args_bak = backend_args.copy() + backend_name = backend_args_bak.pop("backend") + backend = backends[backend_name](**backend_args_bak) + else: + backend = prefix_to_backends[prefix](**backend_args) + return backend + + +def get_file_backend( + uri: Union[str, Path, None] = None, + *, + backend_args: Optional[dict] = None, + enable_singleton: bool = False, + backend_key: Optional[str] = None, +): + """Return a file backend based on the prefix of uri or backend_args. + + Args: + uri (str or Path): Uri to be parsed that contains the file prefix. + backend_args (dict, optional): Arguments to instantiate the + corresponding backend. Defaults to None. + enable_singleton (bool): Whether to enable the singleton pattern. + If it is True, the backend created will be reused if the + signature is same with the previous one. Defaults to False. + backend_key: str: The key to register the backend. Defaults to None. + + Returns: + BaseStorageBackend: Instantiated Backend object. + + Examples: + >>> # get file backend based on the prefix of uri + >>> uri = 's3://path/of/your/file' + >>> backend = get_file_backend(uri) + >>> # get file backend based on the backend_args + >>> backend = get_file_backend(backend_args={'backend': 's3'}) + >>> # backend name has a higher priority if 'backend' in backend_args + >>> backend = get_file_backend(uri, backend_args={'backend': 's3'}) + """ + global backend_instances + if backend_key is not None: + if backend_key in backend_instances: + return backend_instances[backend_key] + + if backend_args is None: + backend_args = {} + + if uri is None and "backend" not in backend_args and backend_key is None: + raise ValueError( + 'uri should not be None when "backend" does not exist in ' "backend_args and backend_key is None" + ) + + if uri is not None: + prefix = _parse_uri_prefix(uri) + else: + prefix = "" + + if enable_singleton: + unique_key = f"{prefix}:{json.dumps(backend_args)}" + if unique_key in backend_instances: + return backend_instances[unique_key] + + backend = _get_file_backend(prefix, backend_args) + backend_instances[unique_key] = backend + if backend_key is not None: + backend_instances[backend_key] = backend + return backend + else: + backend = _get_file_backend(prefix, backend_args) + return backend + + +def get( + filepath: Union[str, Path], + backend_args: Optional[dict] = None, + backend_key: Optional[str] = None, +) -> bytes: + """Read bytes from a given ``filepath`` with 'rb' mode. + + Args: + filepath (str or Path): Path to read data. + backend_args (dict, optional): Arguments to instantiate the + corresponding backend. Defaults to None. + backend_key (str, optional): The key to get the backend from register. + + Returns: + bytes: Expected bytes object. + + Examples: + >>> filepath = '/path/of/file' + >>> get(filepath) + b'hello world' + """ + backend = get_file_backend( + filepath, + backend_args=backend_args, + enable_singleton=True, + backend_key=backend_key, + ) + return backend.get(filepath) + + +def get_text( + filepath: Union[str, Path], + encoding="utf-8", + backend_args: Optional[dict] = None, + backend_key: Optional[str] = None, +) -> str: + """Read text from a given ``filepath`` with 'r' mode. + + Args: + filepath (str or Path): Path to read data. + encoding (str): The encoding format used to open the ``filepath``. + Defaults to 'utf-8'. + backend_args (dict, optional): Arguments to instantiate the + corresponding backend. Defaults to None. + backend_key (str, optional): The key to get the backend from register. + + Returns: + str: Expected text reading from ``filepath``. + + Examples: + >>> filepath = '/path/of/file' + >>> get_text(filepath) + 'hello world' + """ + backend = get_file_backend( + filepath, + backend_args=backend_args, + enable_singleton=True, + backend_key=backend_key, + ) + return backend.get_text(filepath, encoding) + + +def put( + obj: bytes, + filepath: Union[str, Path], + backend_args: Optional[dict] = None, + backend_key: Optional[str] = None, +) -> None: + """Write bytes to a given ``filepath`` with 'wb' mode. + + Note: + ``put`` should create a directory if the directory of + ``filepath`` does not exist. + + Args: + obj (bytes): Data to be written. + filepath (str or Path): Path to write data. + backend_args (dict, optional): Arguments to instantiate the + corresponding backend. Defaults to None. + backend_key (str, optional): The key to get the backend from register. + + Examples: + >>> filepath = '/path/of/file' + >>> put(b'hello world', filepath) + """ + backend = get_file_backend( + filepath, + backend_args=backend_args, + enable_singleton=True, + backend_key=backend_key, + ) + backend.put(obj, filepath) + + +def put_text( + obj: str, + filepath: Union[str, Path], + backend_args: Optional[dict] = None, + backend_key: Optional[str] = None, +) -> None: + """Write text to a given ``filepath`` with 'w' mode. + + Note: + ``put_text`` should create a directory if the directory of + ``filepath`` does not exist. + + Args: + obj (str): Data to be written. + filepath (str or Path): Path to write data. + encoding (str, optional): The encoding format used to open the + ``filepath``. Defaults to 'utf-8'. + backend_args (dict, optional): Arguments to instantiate the + corresponding backend. Defaults to None. + backend_key (str, optional): The key to get the backend from register. + + Examples: + >>> filepath = '/path/of/file' + >>> put_text('hello world', filepath) + """ + backend = get_file_backend( + filepath, + backend_args=backend_args, + enable_singleton=True, + backend_key=backend_key, + ) + backend.put_text(obj, filepath) + + +def exists( + filepath: Union[str, Path], + backend_args: Optional[dict] = None, + backend_key: Optional[str] = None, +) -> bool: + """Check whether a file path exists. + + Args: + filepath (str or Path): Path to be checked whether exists. + backend_args (dict, optional): Arguments to instantiate the + corresponding backend. Defaults to None. + backend_key (str, optional): The key to get the backend from register. + + Returns: + bool: Return ``True`` if ``filepath`` exists, ``False`` otherwise. + + Examples: + >>> filepath = '/path/of/file' + >>> exists(filepath) + True + """ + backend = get_file_backend( + filepath, + backend_args=backend_args, + enable_singleton=True, + backend_key=backend_key, + ) + return backend.exists(filepath) + + +def isdir( + filepath: Union[str, Path], + backend_args: Optional[dict] = None, + backend_key: Optional[str] = None, +) -> bool: + """Check whether a file path is a directory. + + Args: + filepath (str or Path): Path to be checked whether it is a + directory. + backend_args (dict, optional): Arguments to instantiate the + corresponding backend. Defaults to None. + backend_key (str, optional): The key to get the backend from register. + + Returns: + bool: Return ``True`` if ``filepath`` points to a directory, + ``False`` otherwise. + + Examples: + >>> filepath = '/path/of/dir' + >>> isdir(filepath) + True + """ + backend = get_file_backend( + filepath, + backend_args=backend_args, + enable_singleton=True, + backend_key=backend_key, + ) + return backend.isdir(filepath) + + +def isfile( + filepath: Union[str, Path], + backend_args: Optional[dict] = None, + backend_key: Optional[str] = None, +) -> bool: + """Check whether a file path is a file. + + Args: + filepath (str or Path): Path to be checked whether it is a file. + backend_args (dict, optional): Arguments to instantiate the + corresponding backend. Defaults to None. + backend_key (str, optional): The key to get the backend from register. + + Returns: + bool: Return ``True`` if ``filepath`` points to a file, ``False`` + otherwise. + + Examples: + >>> filepath = '/path/of/file' + >>> isfile(filepath) + True + """ + backend = get_file_backend( + filepath, + backend_args=backend_args, + enable_singleton=True, + backend_key=backend_key, + ) + return backend.isfile(filepath) + + +def join_path( + filepath: Union[str, Path], + *filepaths: Union[str, Path], + backend_args: Optional[dict] = None, + backend_key: Optional[str] = None, +) -> Union[str, Path]: + r"""Concatenate all file paths. + + Join one or more filepath components intelligently. The return value + is the concatenation of filepath and any members of \*filepaths. + + Args: + filepath (str or Path): Path to be concatenated. + *filepaths (str or Path): Other paths to be concatenated. + backend_args (dict, optional): Arguments to instantiate the + corresponding backend. Defaults to None. + backend_key (str, optional): The key to get the backend from register. + + Returns: + str: The result of concatenation. + + Examples: + >>> filepath1 = '/path/of/dir1' + >>> filepath2 = 'dir2' + >>> filepath3 = 'path/of/file' + >>> join_path(filepath1, filepath2, filepath3) + '/path/of/dir/dir2/path/of/file' + """ + backend = get_file_backend( + filepath, + backend_args=backend_args, + enable_singleton=True, + backend_key=backend_key, + ) + return backend.join_path(filepath, *filepaths) + + +@contextmanager +def get_local_path( + filepath: Union[str, Path], + backend_args: Optional[dict] = None, + backend_key: Optional[str] = None, +) -> Generator[Union[str, Path], None, None]: + """Download data from ``filepath`` and write the data to local path. + + ``get_local_path`` is decorated by :meth:`contxtlib.contextmanager`. It + can be called with ``with`` statement, and when exists from the + ``with`` statement, the temporary path will be released. + + Note: + If the ``filepath`` is a local path, just return itself and it will + not be released (removed). + + Args: + filepath (str or Path): Path to be read data. + backend_args (dict, optional): Arguments to instantiate the + corresponding backend. Defaults to None. + + Yields: + Iterable[str]: Only yield one path. + + Examples: + >>> with get_local_path('abc/def.jpg') as path: + ... # do something here + """ + backend = get_file_backend( + filepath, + backend_args=backend_args, + enable_singleton=True, + backend_key=backend_key, + ) + with backend.get_local_path(str(filepath)) as local_path: + yield local_path + + +def copyfile( + src: Union[str, Path], + dst: Union[str, Path], + backend_args: Optional[dict] = None, + backend_key: Optional[str] = None, +) -> Union[str, Path]: + """Copy a file src to dst and return the destination file. + + src and dst should have the same prefix. If dst specifies a directory, + the file will be copied into dst using the base filename from src. If + dst specifies a file that already exists, it will be replaced. + + Args: + src (str or Path): A file to be copied. + dst (str or Path): Copy file to dst. + backend_args (dict, optional): Arguments to instantiate the + corresponding backend. Defaults to None. + + Returns: + str: The destination file. + + Raises: + SameFileError: If src and dst are the same file, a SameFileError will + be raised. + + Examples: + >>> # dst is a file + >>> src = '/path/of/file' + >>> dst = '/path1/of/file1' + >>> # src will be copied to '/path1/of/file1' + >>> copyfile(src, dst) + '/path1/of/file1' + + >>> # dst is a directory + >>> dst = '/path1/of/dir' + >>> # src will be copied to '/path1/of/dir/file' + >>> copyfile(src, dst) + '/path1/of/dir/file' + """ + backend = get_file_backend(src, backend_args=backend_args, enable_singleton=True, backend_key=backend_key) + return backend.copyfile(src, dst) + + +def copytree( + src: Union[str, Path], + dst: Union[str, Path], + backend_args: Optional[dict] = None, + backend_key: Optional[str] = None, +) -> Union[str, Path]: + """Recursively copy an entire directory tree rooted at src to a directory + named dst and return the destination directory. + + src and dst should have the same prefix and dst must not already exist. + + Args: + src (str or Path): A directory to be copied. + dst (str or Path): Copy directory to dst. + backend_args (dict, optional): Arguments to instantiate the + corresponding backend. Defaults to None. + backend_key (str, optional): The key to get the backend from register. + + Returns: + str: The destination directory. + + Raises: + FileExistsError: If dst had already existed, a FileExistsError will be + raised. + + Examples: + >>> src = '/path/of/dir1' + >>> dst = '/path/of/dir2' + >>> copytree(src, dst) + '/path/of/dir2' + """ + backend = get_file_backend(src, backend_args=backend_args, enable_singleton=True, backend_key=backend_key) + return backend.copytree(src, dst) + + +def copyfile_from_local( + src: Union[str, Path], + dst: Union[str, Path], + backend_args: Optional[dict] = None, + backend_key: Optional[str] = None, +) -> Union[str, Path]: + """Copy a local file src to dst and return the destination file. + + Note: + If the backend is the instance of LocalBackend, it does the same + thing with :func:`copyfile`. + + Args: + src (str or Path): A local file to be copied. + dst (str or Path): Copy file to dst. + backend_args (dict, optional): Arguments to instantiate the + corresponding backend. Defaults to None. + + Returns: + str: If dst specifies a directory, the file will be copied into dst + using the base filename from src. + + Examples: + >>> # dst is a file + >>> src = '/path/of/file' + >>> dst = 's3://openmmlab/mmengine/file1' + >>> # src will be copied to 's3://openmmlab/mmengine/file1' + >>> copyfile_from_local(src, dst) + s3://openmmlab/mmengine/file1 + + >>> # dst is a directory + >>> dst = 's3://openmmlab/mmengine' + >>> # src will be copied to 's3://openmmlab/mmengine/file'' + >>> copyfile_from_local(src, dst) + 's3://openmmlab/mmengine/file' + """ + backend = get_file_backend(dst, backend_args=backend_args, enable_singleton=True, backend_key=backend_key) + return backend.copyfile_from_local(src, dst) + + +def copytree_from_local( + src: Union[str, Path], + dst: Union[str, Path], + backend_args: Optional[dict] = None, + backend_key: Optional[str] = None, +) -> Union[str, Path]: + """Recursively copy an entire directory tree rooted at src to a directory + named dst and return the destination directory. + + Note: + If the backend is the instance of LocalBackend, it does the same + thing with :func:`copytree`. + + Args: + src (str or Path): A local directory to be copied. + dst (str or Path): Copy directory to dst. + backend_args (dict, optional): Arguments to instantiate the + corresponding backend. Defaults to None. + + Returns: + str: The destination directory. + + Examples: + >>> src = '/path/of/dir' + >>> dst = 's3://openmmlab/mmengine/dir' + >>> copyfile_from_local(src, dst) + 's3://openmmlab/mmengine/dir' + """ + backend = get_file_backend(dst, backend_args=backend_args, enable_singleton=True, backend_key=backend_key) + return backend.copytree_from_local(src, dst) + + +def copyfile_to_local( + src: Union[str, Path], + dst: Union[str, Path], + dst_type: str, # Choose from ["file", "dir"] + backend_args: Optional[dict] = None, + backend_key: Optional[str] = None, +) -> Union[str, Path]: + """Copy the file src to local dst and return the destination file. + + If dst specifies a directory, the file will be copied into dst using + the base filename from src. If dst specifies a file that already + exists, it will be replaced. + + Note: + If the backend is the instance of LocalBackend, it does the same + thing with :func:`copyfile`. + + Args: + src (str or Path): A file to be copied. + dst (str or Path): Copy file to to local dst. + backend_args (dict, optional): Arguments to instantiate the + corresponding backend. Defaults to None. + + Returns: + str: If dst specifies a directory, the file will be copied into dst + using the base filename from src. + + Examples: + >>> # dst is a file + >>> src = 's3://openmmlab/mmengine/file' + >>> dst = '/path/of/file' + >>> # src will be copied to '/path/of/file' + >>> copyfile_to_local(src, dst) + '/path/of/file' + + >>> # dst is a directory + >>> dst = '/path/of/dir' + >>> # src will be copied to '/path/of/dir/file' + >>> copyfile_to_local(src, dst) + '/path/of/dir/file' + """ + assert dst_type in ["file", "dir"] + Path(dst).parent.mkdir(parents=True, exist_ok=True) + backend = get_file_backend(src, backend_args=backend_args, enable_singleton=True, backend_key=backend_key) + return backend.copyfile_to_local(src, dst, dst_type=dst_type) + + +def copytree_to_local( + src: Union[str, Path], + dst: Union[str, Path], + backend_args: Optional[dict] = None, + backend_key: Optional[str] = None, +) -> Union[str, Path]: + """Recursively copy an entire directory tree rooted at src to a local + directory named dst and return the destination directory. + + Note: + If the backend is the instance of LocalBackend, it does the same + thing with :func:`copytree`. + + Args: + src (str or Path): A directory to be copied. + dst (str or Path): Copy directory to local dst. + backend_args (dict, optional): Arguments to instantiate the + corresponding backend. Defaults to None. + + Returns: + str: The destination directory. + + Examples: + >>> src = 's3://openmmlab/mmengine/dir' + >>> dst = '/path/of/dir' + >>> copytree_to_local(src, dst) + '/path/of/dir' + """ + Path(dst).parent.mkdir(parents=True, exist_ok=True) + backend = get_file_backend(dst, backend_args=backend_args, enable_singleton=True, backend_key=backend_key) + return backend.copytree_to_local(src, dst) + + +def remove( + filepath: Union[str, Path], + backend_args: Optional[dict] = None, + backend_key: Optional[str] = None, +) -> None: + """Remove a file. + + Args: + filepath (str, Path): Path to be removed. + backend_args (dict, optional): Arguments to instantiate the + corresponding backend. Defaults to None. + + Raises: + FileNotFoundError: If filepath does not exist, an FileNotFoundError + will be raised. + IsADirectoryError: If filepath is a directory, an IsADirectoryError + will be raised. + + Examples: + >>> filepath = '/path/of/file' + >>> remove(filepath) + """ + backend = get_file_backend( + filepath, + backend_args=backend_args, + enable_singleton=True, + backend_key=backend_key, + ) + backend.remove(filepath) + + +def rmtree( + dir_path: Union[str, Path], + backend_args: Optional[dict] = None, + backend_key: Optional[str] = None, +) -> None: + """Recursively delete a directory tree. + + Args: + dir_path (str or Path): A directory to be removed. + backend_args (dict, optional): Arguments to instantiate the + corresponding backend. Defaults to None. + + Examples: + >>> dir_path = '/path/of/dir' + >>> rmtree(dir_path) + """ + backend = get_file_backend( + dir_path, + backend_args=backend_args, + enable_singleton=True, + backend_key=backend_key, + ) + backend.rmtree(dir_path) + + +def copy_if_symlink_fails( + src: Union[str, Path], + dst: Union[str, Path], + backend_args: Optional[dict] = None, + backend_key: Optional[str] = None, +) -> bool: + """Create a symbolic link pointing to src named dst. + + If failed to create a symbolic link pointing to src, directory copy src to + dst instead. + + Args: + src (str or Path): Create a symbolic link pointing to src. + dst (str or Path): Create a symbolic link named dst. + backend_args (dict, optional): Arguments to instantiate the + corresponding backend. Defaults to None. + + Returns: + bool: Return True if successfully create a symbolic link pointing to + src. Otherwise, return False. + + Examples: + >>> src = '/path/of/file' + >>> dst = '/path1/of/file1' + >>> copy_if_symlink_fails(src, dst) + True + >>> src = '/path/of/dir' + >>> dst = '/path1/of/dir1' + >>> copy_if_symlink_fails(src, dst) + True + """ + backend = get_file_backend(src, backend_args=backend_args, enable_singleton=True, backend_key=backend_key) + return backend.copy_if_symlink_fails(src, dst) + + +def list_dir( + dir_path: Union[str, Path], + backend_args: Optional[dict] = None, + backend_key: Optional[str] = None, +): + """List all folders in an S3 bucket with a given prefix. + + Args: + dir_path (str | Path): Path of the directory. + + Examples: + >>> dir_path = '/path/of/dir' + >>> for file_path in list_dir(dir_path): + ... print(file_path) + """ + if not dir_path.endswith("/"): + dir_path += "/" + backend = get_file_backend( + dir_path, + backend_args=backend_args, + enable_singleton=True, + backend_key=backend_key, + ) + + return backend.list_dir(dir_path) + + +def list_dir_or_file( + dir_path: Union[str, Path], + list_dir: bool = True, + list_file: bool = True, + suffix: Optional[Union[str, Tuple[str]]] = None, + recursive: bool = False, + backend_args: Optional[dict] = None, + backend_key: Optional[str] = None, +) -> Iterator[str]: + """Scan a directory to find the interested directories or files in + arbitrary order. + + Note: + :meth:`list_dir_or_file` returns the path relative to ``dir_path``. + + Args: + dir_path (str or Path): Path of the directory. + list_dir (bool): List the directories. Defaults to True. + list_file (bool): List the path of files. Defaults to True. + suffix (str or tuple[str], optional): File suffix that we are + interested in. Defaults to None. + recursive (bool): If set to True, recursively scan the directory. + Defaults to False. + backend_args (dict, optional): Arguments to instantiate the + corresponding backend. Defaults to None. + + Yields: + Iterable[str]: A relative path to ``dir_path``. + + Examples: + >>> dir_path = '/path/of/dir' + >>> for file_path in list_dir_or_file(dir_path): + ... print(file_path) + >>> # list those files and directories in current directory + >>> for file_path in list_dir_or_file(dir_path): + ... print(file_path) + >>> # only list files + >>> for file_path in list_dir_or_file(dir_path, list_dir=False): + ... print(file_path) + >>> # only list directories + >>> for file_path in list_dir_or_file(dir_path, list_file=False): + ... print(file_path) + >>> # only list files ending with specified suffixes + >>> for file_path in list_dir_or_file(dir_path, suffix='.txt'): + ... print(file_path) + >>> # list all files and directory recursively + >>> for file_path in list_dir_or_file(dir_path, recursive=True): + ... print(file_path) + """ + backend = get_file_backend( + dir_path, + backend_args=backend_args, + enable_singleton=True, + backend_key=backend_key, + ) + yield from backend.list_dir_or_file(dir_path, list_dir, list_file, suffix, recursive) + + +def generate_presigned_url( + url: str, + client_method: str = "get_object", + expires_in: int = 3600, + backend_args: Optional[dict] = None, + backend_key: Optional[str] = None, +) -> str: + """Generate the presigned url of video stream which can be passed to + mmcv.VideoReader. Now only work on s3 backend. + + Note: + Now only work on s3 backend. + + Args: + url (str): Url of video stream. + client_method (str): Method of client, 'get_object' or + 'put_object'. Defaults to 'get_object'. + expires_in (int): expires, in seconds. Defaults to 3600. + backend_args (dict, optional): Arguments to instantiate the + corresponding backend. Defaults to None. + + Returns: + str: Generated presigned url. + """ + backend = get_file_backend(url, backend_args=backend_args, enable_singleton=True, backend_key=backend_key) + return backend.generate_presigned_url(url, client_method, expires_in) + + +def load( + file: Union[str, Path, IO[Any]], + file_format: Optional[str] = None, + file_client_args: Optional[dict] = None, + fast_backend: bool = False, + backend_args: Optional[dict] = None, + backend_key: Optional[str] = None, + **kwargs, +): + """Load data from json/yaml/pickle files. + + This method provides a unified api for loading data from serialized files. + + ``load`` supports loading data from serialized files those can be storaged + in different backends. + + Args: + file (str or :obj:`Path` or file-like object): Filename or a file-like + object. + file_format (str, optional): If not specified, the file format will be + inferred from the file extension, otherwise use the specified one. + Currently supported formats include "json", "yaml/yml" and + "pickle/pkl". + file_client_args (dict, optional): Arguments to instantiate a + FileClient. See :class:`mmengine.fileio.FileClient` for details. + Defaults to None. It will be deprecated in future. Please use + ``backend_args`` instead. + fast_backend: bool: Whether to use multiprocess. Defaults to False. + backend_args (dict, optional): Arguments to instantiate the + prefix of uri corresponding backend. Defaults to None. + New in v0.2.0. + + Examples: + >>> load('/path/of/your/file') # file is storaged in disk + >>> load('https://path/of/your/file') # file is storaged in Internet + >>> load('s3://path/of/your/file') # file is storaged in s3 + + Returns: + The content from the file. + """ + if isinstance(file, Path): + file = str(file) + if file_format is None and isinstance(file, str): + file_format = file.split(".")[-1] + # convert file_format to lower case + file_format = file_format.lower() + if file_format not in file_handlers: + raise TypeError(f"Unsupported format: {file_format}") + + if file_client_args is not None: + warnings.warn( + '"file_client_args" will be deprecated in future. ' 'Please use "backend_args" instead', + DeprecationWarning, + ) + if backend_args is not None: + raise ValueError('"file_client_args and "backend_args" cannot be set at the ' "same time.") + + handler = file_handlers[file_format] + if isinstance(file, str): + if file_client_args is not None: + file_client = FileClient.infer_client(file_client_args, file) + file_backend = file_client + else: + file_backend = get_file_backend( + file, + backend_args=backend_args, + backend_key=backend_key, + enable_singleton=True, + ) + + if handler.str_like: + with StringIO(file_backend.get_text(file)) as f: + obj = handler.load_from_fileobj(f, **kwargs) + else: + if fast_backend: + if hasattr(file_backend, "fast_get"): + with BytesIO(file_backend.fast_get(file)) as f: + obj = handler.load_from_fileobj(f, **kwargs) + else: + warnings.warn( + f"fast_backend is not supported by the backend, type {type(file_backend)} fallback to normal get" + ) + with BytesIO(file_backend.get(file)) as f: + obj = handler.load_from_fileobj(f, **kwargs) + else: + with BytesIO(file_backend.get(file)) as f: + obj = handler.load_from_fileobj(f, **kwargs) + elif hasattr(file, "read"): + obj = handler.load_from_fileobj(file, **kwargs) + else: + raise TypeError('"file" must be a filepath str or a file-object') + return obj + + +def dump( + obj: Any, + file: Union[str, Path, IO[Any], None] = None, + file_format: Optional[str] = None, + file_client_args: Optional[dict] = None, + fast_backend: bool = False, + backend_args: Optional[dict] = None, + backend_key: Optional[str] = None, + **kwargs, +): + """Dump data to json/yaml/pickle strings or files. + + This method provides a unified api for dumping data as strings or to files, + and also supports custom arguments for each file format. + + ``dump`` supports dumping data as strings or to files which is saved to + different backends. + + Args: + obj (any): The python object to be dumped. + file (str or :obj:`Path` or file-like object, optional): If not + specified, then the object is dumped to a str, otherwise to a file + specified by the filename or file-like object. + file_format (str, optional): Same as :func:`load`. + file_client_args (dict, optional): Arguments to instantiate a + FileClient. See :class:`mmengine.fileio.FileClient` for details. + Defaults to None. It will be deprecated in future. Please use + ``backend_args`` instead. + fast_backend: bool: Whether to use multiprocess. Defaults to False. + backend_args (dict, optional): Arguments to instantiate the + prefix of uri corresponding backend. Defaults to None. + New in v0.2.0. + backend_key: str: The key to register the backend. Defaults to None. + + Examples: + >>> dump('hello world', '/path/of/your/file') # disk + >>> dump('hello world', 's3://path/of/your/file') # ceph or s3 + + Returns: + bool: True for success, False otherwise. + """ + if isinstance(file, Path): + file = str(file) + if file_format is None: + if isinstance(file, str): + file_format = file.split(".")[-1] + elif file is None: + raise ValueError("file_format must be specified since file is None") + # convert file_format to lower case + file_format = file_format.lower() + if file_format not in file_handlers: + raise TypeError(f"Unsupported format: {file_format}") + + if file_client_args is not None: + warnings.warn( + '"file_client_args" will be deprecated in future. ' 'Please use "backend_args" instead', + DeprecationWarning, + ) + if backend_args is not None: + raise ValueError('"file_client_args" and "backend_args" cannot be set at the ' "same time.") + + handler = file_handlers[file_format] + if file is None: + return handler.dump_to_str(obj, **kwargs) + elif isinstance(file, str): + if file_client_args is not None: + file_client = FileClient.infer_client(file_client_args, file) + file_backend = file_client + else: + file_backend = get_file_backend( + file, + backend_args=backend_args, + backend_key=backend_key, + enable_singleton=True, + ) + + if handler.str_like: + with StringIO() as f: + handler.dump_to_fileobj(obj, f, **kwargs) + file_backend.put_text(f.getvalue(), file) + else: + with BytesIO() as f: + handler.dump_to_fileobj(obj, f, **kwargs) + if fast_backend: + if hasattr(file_backend, "fast_put"): + file_backend.fast_put(f, file) + else: + warnings.warn("fast_backend is not supported by the backend, fallback to normal put") + file_backend.put(f, file) + else: + file_backend.put(f, file) + elif hasattr(file, "write"): + handler.dump_to_fileobj(obj, file, **kwargs) + else: + raise TypeError('"file" must be a filename str or a file-object') diff --git a/cosmos_transfer1/utils/easy_io/file_client.py b/cosmos_transfer1/utils/easy_io/file_client.py new file mode 100644 index 00000000..be8a378e --- /dev/null +++ b/cosmos_transfer1/utils/easy_io/file_client.py @@ -0,0 +1,450 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from contextlib import contextmanager +from pathlib import Path +from typing import Any, Generator, Iterator, Optional, Tuple, Union + +from cosmos_transfer1.utils.easy_io.backends import BaseStorageBackend, HTTPBackend, LocalBackend + + +def is_filepath(filepath): + return isinstance(filepath, (str, Path)) + + +class HardDiskBackend(LocalBackend): + """Raw hard disks storage backend.""" + + @property + def name(self): + return self.__class__.__name__ + + +class FileClient: + """A general file client to access files in different backends. + + The client loads a file or text in a specified backend from its path + and returns it as a binary or text file. There are two ways to choose a + backend, the name of backend and the prefix of path. Although both of them + can be used to choose a storage backend, ``backend`` has a higher priority + that is if they are all set, the storage backend will be chosen by the + backend argument. If they are all `None`, the disk backend will be chosen. + Note that It can also register other backend accessor with a given name, + prefixes, and backend class. In addition, We use the singleton pattern to + avoid repeated object creation. If the arguments are the same, the same + object will be returned. + + Warning: + `FileClient` will be deprecated in future. Please use io functions + in https://mmengine.readthedocs.io/en/latest/api/fileio.html#file-io + + Args: + backend (str, optional): The storage backend type. Options are "disk", + "memcached", "lmdb", "http" and "s3". Defaults to None. + prefix (str, optional): The prefix of the registered storage backend. + Options are "s3", "http", "https". Defaults to None. + + Examples: + >>> # only set backend + >>> file_client = FileClient(backend='s3') + >>> # only set prefix + >>> file_client = FileClient(prefix='s3') + >>> # set both backend and prefix but use backend to choose client + >>> file_client = FileClient(backend='s3', prefix='s3') + >>> # if the arguments are the same, the same object is returned + >>> file_client1 = FileClient(backend='s3') + >>> file_client1 is file_client + True + + Attributes: + client (:obj:`BaseStorageBackend`): The backend object. + """ + + _backends = { + "disk": HardDiskBackend, + "http": HTTPBackend, + } + + _prefix_to_backends: dict = { + "http": HTTPBackend, + "https": HTTPBackend, + } + + _instances: dict = {} + + client: Any + + def __new__(cls, backend=None, prefix=None, **kwargs): + if backend is None and prefix is None: + backend = "disk" + if backend is not None and backend not in cls._backends: + raise ValueError( + f"Backend {backend} is not supported. Currently supported ones" f" are {list(cls._backends.keys())}" + ) + if prefix is not None and prefix not in cls._prefix_to_backends: + raise ValueError( + f"prefix {prefix} is not supported. Currently supported ones " + f"are {list(cls._prefix_to_backends.keys())}" + ) + + # concatenate the arguments to a unique key for determining whether + # objects with the same arguments were created + arg_key = f"{backend}:{prefix}" + for key, value in kwargs.items(): + arg_key += f":{key}:{value}" + + # if a backend was overridden, it will create a new object + if arg_key in cls._instances: + _instance = cls._instances[arg_key] + else: + # create a new object and put it to _instance + _instance = super().__new__(cls) + if backend is not None: + _instance.client = cls._backends[backend](**kwargs) + else: + _instance.client = cls._prefix_to_backends[prefix](**kwargs) + + cls._instances[arg_key] = _instance + + return _instance + + @property + def name(self): + return self.client.name + + @property + def allow_symlink(self): + return self.client.allow_symlink + + @staticmethod + def parse_uri_prefix(uri: Union[str, Path]) -> Optional[str]: + """Parse the prefix of a uri. + + Args: + uri (str | Path): Uri to be parsed that contains the file prefix. + + Examples: + >>> FileClient.parse_uri_prefix('s3://path/of/your/file') + 's3' + + Returns: + str | None: Return the prefix of uri if the uri contains '://' else + ``None``. + """ + assert is_filepath(uri) + uri = str(uri) + if "://" not in uri: + return None + else: + prefix, _ = uri.split("://") + if ":" in prefix: + _, prefix = prefix.split(":") + return prefix + + @classmethod + def infer_client( + cls, + file_client_args: Optional[dict] = None, + uri: Optional[Union[str, Path]] = None, + ) -> "FileClient": + """Infer a suitable file client based on the URI and arguments. + + Args: + file_client_args (dict, optional): Arguments to instantiate a + FileClient. Defaults to None. + uri (str | Path, optional): Uri to be parsed that contains the file + prefix. Defaults to None. + + Examples: + >>> uri = 's3://path/of/your/file' + >>> file_client = FileClient.infer_client(uri=uri) + >>> file_client_args = {'backend': 's3'} + >>> file_client = FileClient.infer_client(file_client_args) + + Returns: + FileClient: Instantiated FileClient object. + """ + assert file_client_args is not None or uri is not None + if file_client_args is None: + file_prefix = cls.parse_uri_prefix(uri) # type: ignore + return cls(prefix=file_prefix) + else: + return cls(**file_client_args) + + @classmethod + def _register_backend(cls, name, backend, force=False, prefixes=None): + if not isinstance(name, str): + raise TypeError("the backend name should be a string, " f"but got {type(name)}") + if not inspect.isclass(backend): + raise TypeError(f"backend should be a class but got {type(backend)}") + if not issubclass(backend, BaseStorageBackend): + raise TypeError(f"backend {backend} is not a subclass of BaseStorageBackend") + if not force and name in cls._backends: + raise KeyError( + f"{name} is already registered as a storage backend, " 'add "force=True" if you want to override it' + ) + + if name in cls._backends and force: + for arg_key, instance in list(cls._instances.items()): + if isinstance(instance.client, cls._backends[name]): + cls._instances.pop(arg_key) + cls._backends[name] = backend + + if prefixes is not None: + if isinstance(prefixes, str): + prefixes = [prefixes] + else: + assert isinstance(prefixes, (list, tuple)) + for prefix in prefixes: + if prefix not in cls._prefix_to_backends: + cls._prefix_to_backends[prefix] = backend + elif (prefix in cls._prefix_to_backends) and force: + overridden_backend = cls._prefix_to_backends[prefix] + for arg_key, instance in list(cls._instances.items()): + if isinstance(instance.client, overridden_backend): + cls._instances.pop(arg_key) + else: + raise KeyError( + f"{prefix} is already registered as a storage backend," + ' add "force=True" if you want to override it' + ) + + @classmethod + def register_backend(cls, name, backend=None, force=False, prefixes=None): + """Register a backend to FileClient. + + This method can be used as a normal class method or a decorator. + + .. code-block:: python + + class NewBackend(BaseStorageBackend): + + def get(self, filepath): + return filepath + + def get_text(self, filepath): + return filepath + + FileClient.register_backend('new', NewBackend) + + or + + .. code-block:: python + + @FileClient.register_backend('new') + class NewBackend(BaseStorageBackend): + + def get(self, filepath): + return filepath + + def get_text(self, filepath): + return filepath + + Args: + name (str): The name of the registered backend. + backend (class, optional): The backend class to be registered, + which must be a subclass of :class:`BaseStorageBackend`. + When this method is used as a decorator, backend is None. + Defaults to None. + force (bool, optional): Whether to override the backend if the name + has already been registered. Defaults to False. + prefixes (str or list[str] or tuple[str], optional): The prefixes + of the registered storage backend. Defaults to None. + `New in version 1.3.15.` + """ + if backend is not None: + cls._register_backend(name, backend, force=force, prefixes=prefixes) + return + + def _register(backend_cls): + cls._register_backend(name, backend_cls, force=force, prefixes=prefixes) + return backend_cls + + return _register + + def get(self, filepath: Union[str, Path]) -> Union[bytes, memoryview]: + """Read data from a given ``filepath`` with 'rb' mode. + + Note: + There are two types of return values for ``get``, one is ``bytes`` + and the other is ``memoryview``. The advantage of using memoryview + is that you can avoid copying, and if you want to convert it to + ``bytes``, you can use ``.tobytes()``. + + Args: + filepath (str or Path): Path to read data. + + Returns: + bytes | memoryview: Expected bytes object or a memory view of the + bytes object. + """ + return self.client.get(filepath) + + def get_text(self, filepath: Union[str, Path], encoding="utf-8") -> str: + """Read data from a given ``filepath`` with 'r' mode. + + Args: + filepath (str or Path): Path to read data. + encoding (str): The encoding format used to open the ``filepath``. + Defaults to 'utf-8'. + + Returns: + str: Expected text reading from ``filepath``. + """ + return self.client.get_text(filepath, encoding) + + def put(self, obj: bytes, filepath: Union[str, Path]) -> None: + """Write data to a given ``filepath`` with 'wb' mode. + + Note: + ``put`` should create a directory if the directory of ``filepath`` + does not exist. + + Args: + obj (bytes): Data to be written. + filepath (str or Path): Path to write data. + """ + self.client.put(obj, filepath) + + def put_text(self, obj: str, filepath: Union[str, Path]) -> None: + """Write data to a given ``filepath`` with 'w' mode. + + Note: + ``put_text`` should create a directory if the directory of + ``filepath`` does not exist. + + Args: + obj (str): Data to be written. + filepath (str or Path): Path to write data. + encoding (str, optional): The encoding format used to open the + `filepath`. Defaults to 'utf-8'. + """ + self.client.put_text(obj, filepath) + + def remove(self, filepath: Union[str, Path]) -> None: + """Remove a file. + + Args: + filepath (str, Path): Path to be removed. + """ + self.client.remove(filepath) + + def exists(self, filepath: Union[str, Path]) -> bool: + """Check whether a file path exists. + + Args: + filepath (str or Path): Path to be checked whether exists. + + Returns: + bool: Return ``True`` if ``filepath`` exists, ``False`` otherwise. + """ + return self.client.exists(filepath) + + def isdir(self, filepath: Union[str, Path]) -> bool: + """Check whether a file path is a directory. + + Args: + filepath (str or Path): Path to be checked whether it is a + directory. + + Returns: + bool: Return ``True`` if ``filepath`` points to a directory, + ``False`` otherwise. + """ + return self.client.isdir(filepath) + + def isfile(self, filepath: Union[str, Path]) -> bool: + """Check whether a file path is a file. + + Args: + filepath (str or Path): Path to be checked whether it is a file. + + Returns: + bool: Return ``True`` if ``filepath`` points to a file, ``False`` + otherwise. + """ + return self.client.isfile(filepath) + + def join_path(self, filepath: Union[str, Path], *filepaths: Union[str, Path]) -> str: + r"""Concatenate all file paths. + + Join one or more filepath components intelligently. The return value + is the concatenation of filepath and any members of \*filepaths. + + Args: + filepath (str or Path): Path to be concatenated. + + Returns: + str: The result of concatenation. + """ + return self.client.join_path(filepath, *filepaths) + + @contextmanager + def get_local_path(self, filepath: Union[str, Path]) -> Generator[Union[str, Path], None, None]: + """Download data from ``filepath`` and write the data to local path. + + ``get_local_path`` is decorated by :meth:`contxtlib.contextmanager`. It + can be called with ``with`` statement, and when exists from the + ``with`` statement, the temporary path will be released. + + Note: + If the ``filepath`` is a local path, just return itself. + + .. warning:: + ``get_local_path`` is an experimental interface that may change in + the future. + + Args: + filepath (str or Path): Path to be read data. + + Examples: + >>> file_client = FileClient(prefix='s3') + >>> with file_client.get_local_path('s3://bucket/abc.jpg') as path: + ... # do something here + + Yields: + Iterable[str]: Only yield one path. + """ + with self.client.get_local_path(str(filepath)) as local_path: + yield local_path + + def list_dir_or_file( # pylint: disable=too-many-arguments + self, + dir_path: Union[str, Path], + list_dir: bool = True, + list_file: bool = True, + suffix: Optional[Union[str, Tuple[str]]] = None, + recursive: bool = False, + ) -> Iterator[str]: + """Scan a directory to find the interested directories or files in + arbitrary order. + + Note: + :meth:`list_dir_or_file` returns the path relative to ``dir_path``. + + Args: + dir_path (str | Path): Path of the directory. + list_dir (bool): List the directories. Defaults to True. + list_file (bool): List the path of files. Defaults to True. + suffix (str or tuple[str], optional): File suffix + that we are interested in. Defaults to None. + recursive (bool): If set to True, recursively scan the + directory. Defaults to False. + + Yields: + Iterable[str]: A relative path to ``dir_path``. + """ + yield from self.client.list_dir_or_file(dir_path, list_dir, list_file, suffix, recursive) diff --git a/cosmos_transfer1/utils/easy_io/handlers/__init__.py b/cosmos_transfer1/utils/easy_io/handlers/__init__.py new file mode 100644 index 00000000..aafac064 --- /dev/null +++ b/cosmos_transfer1/utils/easy_io/handlers/__init__.py @@ -0,0 +1,29 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from cosmos_transfer1.utils.easy_io.handlers.base import BaseFileHandler +from cosmos_transfer1.utils.easy_io.handlers.json_handler import JsonHandler +from cosmos_transfer1.utils.easy_io.handlers.pickle_handler import PickleHandler +from cosmos_transfer1.utils.easy_io.handlers.registry_utils import file_handlers, register_handler +from cosmos_transfer1.utils.easy_io.handlers.yaml_handler import YamlHandler + +__all__ = [ + "BaseFileHandler", + "JsonHandler", + "PickleHandler", + "YamlHandler", + "register_handler", + "file_handlers", +] diff --git a/cosmos_transfer1/utils/easy_io/handlers/base.py b/cosmos_transfer1/utils/easy_io/handlers/base.py new file mode 100644 index 00000000..5e5dcbca --- /dev/null +++ b/cosmos_transfer1/utils/easy_io/handlers/base.py @@ -0,0 +1,44 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from abc import ABCMeta, abstractmethod + + +class BaseFileHandler(metaclass=ABCMeta): + # `str_like` is a flag to indicate whether the type of file object is + # str-like object or bytes-like object. Pickle only processes bytes-like + # objects but json only processes str-like object. If it is str-like + # object, `StringIO` will be used to process the buffer. + str_like = True + + @abstractmethod + def load_from_fileobj(self, file, **kwargs): + pass + + @abstractmethod + def dump_to_fileobj(self, obj, file, **kwargs): + pass + + @abstractmethod + def dump_to_str(self, obj, **kwargs): + pass + + def load_from_path(self, filepath, mode="r", **kwargs): + with open(filepath, mode) as f: + return self.load_from_fileobj(f, **kwargs) + + def dump_to_path(self, obj, filepath, mode="w", **kwargs): + with open(filepath, mode) as f: + self.dump_to_fileobj(obj, f, **kwargs) diff --git a/cosmos_transfer1/utils/easy_io/handlers/csv_handler.py b/cosmos_transfer1/utils/easy_io/handlers/csv_handler.py new file mode 100644 index 00000000..c76294ca --- /dev/null +++ b/cosmos_transfer1/utils/easy_io/handlers/csv_handler.py @@ -0,0 +1,42 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import csv +from io import StringIO + +from cosmos_transfer1.utils.easy_io.handlers.base import BaseFileHandler + + +class CsvHandler(BaseFileHandler): + def load_from_fileobj(self, file, **kwargs): + del kwargs + reader = csv.reader(file) + return list(reader) + + def dump_to_fileobj(self, obj, file, **kwargs): + del kwargs + writer = csv.writer(file) + if not all(isinstance(row, list) for row in obj): + raise ValueError("Each row must be a list") + writer.writerows(obj) + + def dump_to_str(self, obj, **kwargs): + del kwargs + output = StringIO() + writer = csv.writer(output) + if not all(isinstance(row, list) for row in obj): + raise ValueError("Each row must be a list") + writer.writerows(obj) + return output.getvalue() diff --git a/cosmos_transfer1/utils/easy_io/handlers/gzip_handler.py b/cosmos_transfer1/utils/easy_io/handlers/gzip_handler.py new file mode 100644 index 00000000..2e063a73 --- /dev/null +++ b/cosmos_transfer1/utils/easy_io/handlers/gzip_handler.py @@ -0,0 +1,33 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import gzip +import pickle +from io import BytesIO +from typing import Any + +from cosmos_transfer1.utils.easy_io.handlers.pickle_handler import PickleHandler + + +class GzipHandler(PickleHandler): + str_like = False + + def load_from_fileobj(self, file: BytesIO, **kwargs): + with gzip.GzipFile(fileobj=file, mode="rb") as f: + return pickle.load(f) + + def dump_to_fileobj(self, obj: Any, file: BytesIO, **kwargs): + with gzip.GzipFile(fileobj=file, mode="wb") as f: + pickle.dump(obj, f) diff --git a/cosmos_transfer1/utils/easy_io/handlers/imageio_video_handler.py b/cosmos_transfer1/utils/easy_io/handlers/imageio_video_handler.py new file mode 100644 index 00000000..67dbbc27 --- /dev/null +++ b/cosmos_transfer1/utils/easy_io/handlers/imageio_video_handler.py @@ -0,0 +1,91 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import IO + +import numpy as np +import torch + +from cosmos_transfer1.utils.easy_io.handlers.base import BaseFileHandler + +try: + import imageio +except ImportError: + imageio = None + + +class ImageioVideoHandler(BaseFileHandler): + str_like = False + + def load_from_fileobj(self, file: IO[bytes], format: str = "mp4", mode: str = "rgb", **kwargs): + """ + Load video from a file-like object using imageio with specified format and color mode. + + Parameters: + file (IO[bytes]): A file-like object containing video data. + format (str): Format of the video file (default 'mp4'). + mode (str): Color mode of the video, 'rgb' or 'gray' (default 'rgb'). + + Returns: + tuple: A tuple containing an array of video frames and metadata about the video. + """ + file.seek(0) + video_reader = imageio.get_reader(file, format, **kwargs) + + video_frames = [] + for frame in video_reader: + if mode == "gray": + import cv2 # Convert frame to grayscale if mode is gray + + frame = cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY) + frame = np.expand_dims(frame, axis=2) # Keep frame dimensions consistent + video_frames.append(frame) + + return np.array(video_frames), video_reader.get_meta_data() + + def dump_to_fileobj( + self, + obj: np.ndarray | torch.Tensor, + file: IO[bytes], + format: str = "mp4", # pylint: disable=redefined-builtin + fps: int = 17, + quality: int = 5, + **kwargs, + ): + """ + Save an array of video frames to a file-like object using imageio. + + Parameters: + obj (np.ndarray): An array of frames to be saved as video. + file (IO[bytes]): A file-like object to which the video data will be written. + format (str): Format of the video file (default 'mp4'). + fps (int): Frames per second of the output video (default 30). + + """ + if isinstance(obj, torch.Tensor): + assert obj.dtype == torch.uint8 + obj = obj.cpu().numpy() + h, w = obj.shape[1:-1] + kwargs = { + "fps": fps, + "quality": quality, + "macro_block_size": 1, + "ffmpeg_params": ["-s", f"{w}x{h}"], + "output_params": ["-f", "mp4"], + } + imageio.mimsave(file, obj, format, **kwargs) + + def dump_to_str(self, obj, **kwargs): + raise NotImplementedError diff --git a/cosmos_transfer1/utils/easy_io/handlers/json_handler.py b/cosmos_transfer1/utils/easy_io/handlers/json_handler.py new file mode 100644 index 00000000..beb55c61 --- /dev/null +++ b/cosmos_transfer1/utils/easy_io/handlers/json_handler.py @@ -0,0 +1,49 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json + +import numpy as np + +from cosmos_transfer1.utils.easy_io.handlers.base import BaseFileHandler + + +def set_default(obj): + """Set default json values for non-serializable values. + + It helps convert ``set``, ``range`` and ``np.ndarray`` data types to list. + It also converts ``np.generic`` (including ``np.int32``, ``np.float32``, + etc.) into plain numbers of plain python built-in types. + """ + if isinstance(obj, (set, range)): + return list(obj) + elif isinstance(obj, np.ndarray): + return obj.tolist() + elif isinstance(obj, np.generic): + return obj.item() + raise TypeError(f"{type(obj)} is unsupported for json dump") + + +class JsonHandler(BaseFileHandler): + def load_from_fileobj(self, file): + return json.load(file) + + def dump_to_fileobj(self, obj, file, **kwargs): + kwargs.setdefault("default", set_default) + json.dump(obj, file, **kwargs) + + def dump_to_str(self, obj, **kwargs): + kwargs.setdefault("default", set_default) + return json.dumps(obj, **kwargs) diff --git a/cosmos_transfer1/utils/easy_io/handlers/jsonl_handler.py b/cosmos_transfer1/utils/easy_io/handlers/jsonl_handler.py new file mode 100644 index 00000000..000ffb19 --- /dev/null +++ b/cosmos_transfer1/utils/easy_io/handlers/jsonl_handler.py @@ -0,0 +1,80 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +from typing import IO + +import numpy as np + +from cosmos_transfer1.utils.easy_io.handlers.base import BaseFileHandler + + +def set_default(obj): + """Set default json values for non-serializable values. + + It helps convert ``set``, ``range`` and ``np.ndarray`` data types to list. + It also converts ``np.generic`` (including ``np.int32``, ``np.float32``, + etc.) into plain numbers of plain python built-in types. + """ + if isinstance(obj, (set, range)): + return list(obj) + elif isinstance(obj, np.ndarray): + return obj.tolist() + elif isinstance(obj, np.generic): + return obj.item() + raise TypeError(f"{type(obj)} is unsupported for json dump") + + +class JsonlHandler(BaseFileHandler): + """Handler for JSON lines (JSONL) files.""" + + def load_from_fileobj(self, file: IO[bytes]): + """Load JSON objects from a newline-delimited JSON (JSONL) file object. + + Returns: + A list of Python objects loaded from each JSON line. + """ + data = [] + for line in file: + line = line.strip() + if not line: + continue # skip empty lines if any + data.append(json.loads(line)) + return data + + def dump_to_fileobj(self, obj: IO[bytes], file, **kwargs): + """Dump a list of objects to a newline-delimited JSON (JSONL) file object. + + Args: + obj: A list (or iterable) of objects to dump line by line. + """ + kwargs.setdefault("default", set_default) + for item in obj: + file.write(json.dumps(item, **kwargs) + "\n") + + def dump_to_str(self, obj, **kwargs): + """Dump a list of objects to a newline-delimited JSON (JSONL) string.""" + kwargs.setdefault("default", set_default) + lines = [json.dumps(item, **kwargs) for item in obj] + return "\n".join(lines) + + +if __name__ == "__main__": + from cosmos_transfer1.utils.easy_io import easy_io + + easy_io.dump([1, 2, 3], "test.jsonl", file_format="jsonl") + print(easy_io.load("test.jsonl")) + easy_io.dump([{"key1": 1, "key2": 2}, {"key1": 3, "key2": 4}], "test.jsonl", file_format="jsonl") + print(easy_io.load("test.jsonl")) diff --git a/cosmos_transfer1/utils/easy_io/handlers/np_handler.py b/cosmos_transfer1/utils/easy_io/handlers/np_handler.py new file mode 100644 index 00000000..070396a1 --- /dev/null +++ b/cosmos_transfer1/utils/easy_io/handlers/np_handler.py @@ -0,0 +1,89 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from io import BytesIO +from typing import IO, Any + +import numpy as np + +from cosmos_transfer1.utils.easy_io.handlers.base import BaseFileHandler + + +class NumpyHandler(BaseFileHandler): + str_like = False + + def load_from_fileobj(self, file: IO[bytes], **kwargs) -> Any: + """ + Load a NumPy array from a file-like object. + + Parameters: + file (IO[bytes]): The file-like object containing the NumPy array data. + **kwargs: Additional keyword arguments passed to `np.load`. + + Returns: + numpy.ndarray: The loaded NumPy array. + """ + return np.load(file, **kwargs) + + def load_from_path(self, filepath: str, **kwargs) -> Any: + """ + Load a NumPy array from a file path. + + Parameters: + filepath (str): The path to the file to load. + **kwargs: Additional keyword arguments passed to `np.load`. + + Returns: + numpy.ndarray: The loaded NumPy array. + """ + return super().load_from_path(filepath, mode="rb", **kwargs) + + def dump_to_str(self, obj: np.ndarray, **kwargs) -> str: + """ + Serialize a NumPy array to a string in binary format. + + Parameters: + obj (np.ndarray): The NumPy array to serialize. + **kwargs: Additional keyword arguments passed to `np.save`. + + Returns: + str: The serialized NumPy array as a string. + """ + with BytesIO() as f: + np.save(f, obj, **kwargs) + return f.getvalue() + + def dump_to_fileobj(self, obj: np.ndarray, file: IO[bytes], **kwargs): + """ + Dump a NumPy array to a file-like object. + + Parameters: + obj (np.ndarray): The NumPy array to dump. + file (IO[bytes]): The file-like object to which the array is dumped. + **kwargs: Additional keyword arguments passed to `np.save`. + """ + np.save(file, obj, **kwargs) + + def dump_to_path(self, obj: np.ndarray, filepath: str, **kwargs): + """ + Dump a NumPy array to a file path. + + Parameters: + obj (np.ndarray): The NumPy array to dump. + filepath (str): The file path where the array should be saved. + **kwargs: Additional keyword arguments passed to `np.save`. + """ + with open(filepath, "wb") as f: + np.save(f, obj, **kwargs) diff --git a/cosmos_transfer1/utils/easy_io/handlers/pandas_handler.py b/cosmos_transfer1/utils/easy_io/handlers/pandas_handler.py new file mode 100644 index 00000000..3389cfc8 --- /dev/null +++ b/cosmos_transfer1/utils/easy_io/handlers/pandas_handler.py @@ -0,0 +1,31 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pandas as pd + +from cosmos_transfer1.utils.easy_io.handlers.base import BaseFileHandler # isort:skip + + +class PandasHandler(BaseFileHandler): + str_like = False + + def load_from_fileobj(self, file, **kwargs): + return pd.read_csv(file, **kwargs) + + def dump_to_fileobj(self, obj, file, **kwargs): + obj.to_csv(file, **kwargs) + + def dump_to_str(self, obj, **kwargs): + raise NotImplementedError("PandasHandler does not support dumping to str") diff --git a/cosmos_transfer1/utils/easy_io/handlers/pickle_handler.py b/cosmos_transfer1/utils/easy_io/handlers/pickle_handler.py new file mode 100644 index 00000000..618e750a --- /dev/null +++ b/cosmos_transfer1/utils/easy_io/handlers/pickle_handler.py @@ -0,0 +1,42 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pickle +from io import BytesIO +from typing import Any + +from cosmos_transfer1.utils.easy_io.handlers.base import BaseFileHandler + + +class PickleHandler(BaseFileHandler): + str_like = False + + def load_from_fileobj(self, file: BytesIO, **kwargs): + return pickle.load(file, **kwargs) + + def load_from_path(self, filepath, **kwargs): + return super().load_from_path(filepath, mode="rb", **kwargs) + + def dump_to_str(self, obj, **kwargs): + kwargs.setdefault("protocol", 2) + return pickle.dumps(obj, **kwargs) + + def dump_to_fileobj(self, obj: Any, file: BytesIO, **kwargs): + kwargs.setdefault("protocol", 2) + pickle.dump(obj, file, **kwargs) + + def dump_to_path(self, obj, filepath, **kwargs): + with open(filepath, "wb") as f: + pickle.dump(obj, f, **kwargs) diff --git a/cosmos_transfer1/utils/easy_io/handlers/pil_handler.py b/cosmos_transfer1/utils/easy_io/handlers/pil_handler.py new file mode 100644 index 00000000..618ca9d2 --- /dev/null +++ b/cosmos_transfer1/utils/easy_io/handlers/pil_handler.py @@ -0,0 +1,96 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import IO, Optional, Tuple, Union + +import numpy as np + +from cosmos_transfer1.utils.easy_io.handlers.base import BaseFileHandler + +try: + from PIL import Image +except ImportError: + Image = None + + +class PILHandler(BaseFileHandler): + format: str + str_like = False + + def load_from_fileobj( + self, + file: IO[bytes], + fmt: str = "pil", + size: Optional[Union[int, Tuple[int, int]]] = None, + **kwargs, + ): + """ + Load an image from a file-like object and return it in a specified format. + + Args: + file (IO[bytes]): A file-like object containing the image data. + fmt (str): The format to convert the image into. Options are \ + 'numpy', 'np', 'npy', 'type' (all return numpy arrays), \ + 'pil' (returns PIL Image), 'th', 'torch' (returns a torch tensor). + size (Optional[Union[int, Tuple[int, int]]]): The new size of the image as a single integer \ + or a tuple of (width, height). If specified, the image is resized accordingly. + **kwargs: Additional keyword arguments that can be passed to conversion functions. + + Returns: + Image data in the format specified by `fmt`. + + Raises: + IOError: If the image cannot be loaded or processed. + ValueError: If the specified format is unsupported. + """ + try: + img = Image.open(file) + img.load() # Explicitly load the image data + if size is not None: + if isinstance(size, int): + size = ( + size, + size, + ) # create a tuple if only one integer is provided + img = img.resize(size, Image.ANTIALIAS) + + # Return the image in the requested format + if fmt in ["numpy", "np", "npy"]: + return np.array(img, **kwargs) + if fmt == "pil": + return img + if fmt in ["th", "torch"]: + import torch + + # Convert to tensor + img_tensor = torch.from_numpy(np.array(img, **kwargs)) + # Convert image from HxWxC to CxHxW + if img_tensor.ndim == 3: + img_tensor = img_tensor.permute(2, 0, 1) + return img_tensor + raise ValueError( + "Unsupported format. Supported formats are 'numpy', 'np', 'npy', 'pil', 'th', and 'torch'." + ) + except Exception as e: + raise IOError(f"Unable to load image: {e}") from e + + def dump_to_fileobj(self, obj, file: IO[bytes], **kwargs): + if "format" not in kwargs: + kwargs["format"] = self.format + kwargs["format"] = "JPEG" if self.format.lower() == "jpg" else self.format.upper() + obj.save(file, **kwargs) + + def dump_to_str(self, obj, **kwargs): + raise NotImplementedError diff --git a/cosmos_transfer1/utils/easy_io/handlers/registry_utils.py b/cosmos_transfer1/utils/easy_io/handlers/registry_utils.py new file mode 100644 index 00000000..286d0a6e --- /dev/null +++ b/cosmos_transfer1/utils/easy_io/handlers/registry_utils.py @@ -0,0 +1,80 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from cosmos_transfer1.utils.easy_io.handlers.base import BaseFileHandler +from cosmos_transfer1.utils.easy_io.handlers.csv_handler import CsvHandler +from cosmos_transfer1.utils.easy_io.handlers.gzip_handler import GzipHandler +from cosmos_transfer1.utils.easy_io.handlers.imageio_video_handler import ImageioVideoHandler +from cosmos_transfer1.utils.easy_io.handlers.json_handler import JsonHandler +from cosmos_transfer1.utils.easy_io.handlers.jsonl_handler import JsonlHandler +from cosmos_transfer1.utils.easy_io.handlers.np_handler import NumpyHandler +from cosmos_transfer1.utils.easy_io.handlers.pandas_handler import PandasHandler +from cosmos_transfer1.utils.easy_io.handlers.pickle_handler import PickleHandler +from cosmos_transfer1.utils.easy_io.handlers.pil_handler import PILHandler +from cosmos_transfer1.utils.easy_io.handlers.tarfile_handler import TarHandler +from cosmos_transfer1.utils.easy_io.handlers.torch_handler import TorchHandler +from cosmos_transfer1.utils.easy_io.handlers.torchjit_handler import TorchJitHandler +from cosmos_transfer1.utils.easy_io.handlers.txt_handler import TxtHandler +from cosmos_transfer1.utils.easy_io.handlers.yaml_handler import YamlHandler + +file_handlers = { + "json": JsonHandler(), + "yaml": YamlHandler(), + "yml": YamlHandler(), + "pickle": PickleHandler(), + "pkl": PickleHandler(), + "tar": TarHandler(), + "jit": TorchJitHandler(), + "npy": NumpyHandler(), + "txt": TxtHandler(), + "csv": CsvHandler(), + "pandas": PandasHandler(), + "gz": GzipHandler(), + "jsonl": JsonlHandler(), +} + +for torch_type in ["pt", "pth", "ckpt"]: + file_handlers[torch_type] = TorchHandler() +for img_type in ["jpg", "jpeg", "png", "bmp", "gif"]: + file_handlers[img_type] = PILHandler() + file_handlers[img_type].format = img_type +for video_type in ["mp4", "avi", "mov", "webm", "flv", "wmv"]: + file_handlers[video_type] = ImageioVideoHandler() + + +def _register_handler(handler, file_formats): + """Register a handler for some file extensions. + + Args: + handler (:obj:`BaseFileHandler`): Handler to be registered. + file_formats (str or list[str]): File formats to be handled by this + handler. + """ + if not isinstance(handler, BaseFileHandler): + raise TypeError(f"handler must be a child of BaseFileHandler, not {type(handler)}") + if isinstance(file_formats, str): + file_formats = [file_formats] + if not all([isinstance(item, str) for item in file_formats]): + raise TypeError("file_formats must be a str or a list of str") + for ext in file_formats: + file_handlers[ext] = handler + + +def register_handler(file_formats, **kwargs): + def wrap(cls): + _register_handler(cls(**kwargs), file_formats) + return cls + + return wrap diff --git a/cosmos_transfer1/utils/easy_io/handlers/tarfile_handler.py b/cosmos_transfer1/utils/easy_io/handlers/tarfile_handler.py new file mode 100644 index 00000000..9992569d --- /dev/null +++ b/cosmos_transfer1/utils/easy_io/handlers/tarfile_handler.py @@ -0,0 +1,39 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import tarfile + +from cosmos_transfer1.utils.easy_io.handlers.base import BaseFileHandler + + +class TarHandler(BaseFileHandler): + str_like = False + + def load_from_fileobj(self, file, mode="r|*", **kwargs): + return tarfile.open(fileobj=file, mode=mode, **kwargs) + + def load_from_path(self, filepath, mode="r|*", **kwargs): + return tarfile.open(filepath, mode=mode, **kwargs) + + def dump_to_fileobj(self, obj, file, mode="w", **kwargs): + with tarfile.open(fileobj=file, mode=mode) as tar: + tar.add(obj, **kwargs) + + def dump_to_path(self, obj, filepath, mode="w", **kwargs): + with tarfile.open(filepath, mode=mode) as tar: + tar.add(obj, **kwargs) + + def dump_to_str(self, obj, **kwargs): + raise NotImplementedError diff --git a/cosmos_transfer1/utils/easy_io/handlers/torch_handler.py b/cosmos_transfer1/utils/easy_io/handlers/torch_handler.py new file mode 100644 index 00000000..71adc6e1 --- /dev/null +++ b/cosmos_transfer1/utils/easy_io/handlers/torch_handler.py @@ -0,0 +1,34 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +try: + import torch +except ImportError: + torch = None + +from cosmos_transfer1.utils.easy_io.handlers.base import BaseFileHandler + + +class TorchHandler(BaseFileHandler): + str_like = False + + def load_from_fileobj(self, file, **kwargs): + return torch.load(file, **kwargs) + + def dump_to_fileobj(self, obj, file, **kwargs): + torch.save(obj, file, **kwargs) + + def dump_to_str(self, obj, **kwargs): + raise NotImplementedError diff --git a/cosmos_transfer1/utils/easy_io/handlers/torchjit_handler.py b/cosmos_transfer1/utils/easy_io/handlers/torchjit_handler.py new file mode 100644 index 00000000..6711cddf --- /dev/null +++ b/cosmos_transfer1/utils/easy_io/handlers/torchjit_handler.py @@ -0,0 +1,34 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +try: + import torch +except ImportError: + torch = None + +from cosmos_transfer1.utils.easy_io.handlers.base import BaseFileHandler + + +class TorchJitHandler(BaseFileHandler): + str_like = False + + def load_from_fileobj(self, file, **kwargs): + return torch.jit.load(file, **kwargs) + + def dump_to_fileobj(self, obj, file, **kwargs): + torch.jit.save(obj, file, **kwargs) + + def dump_to_str(self, obj, **kwargs): + raise NotImplementedError diff --git a/cosmos_transfer1/utils/easy_io/handlers/txt_handler.py b/cosmos_transfer1/utils/easy_io/handlers/txt_handler.py new file mode 100644 index 00000000..d42408d7 --- /dev/null +++ b/cosmos_transfer1/utils/easy_io/handlers/txt_handler.py @@ -0,0 +1,34 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from cosmos_transfer1.utils.easy_io.handlers.base import BaseFileHandler + + +class TxtHandler(BaseFileHandler): + def load_from_fileobj(self, file, **kwargs): + del kwargs + return file.read() + + def dump_to_fileobj(self, obj, file, **kwargs): + del kwargs + if not isinstance(obj, str): + obj = str(obj) + file.write(obj) + + def dump_to_str(self, obj, **kwargs): + del kwargs + if not isinstance(obj, str): + obj = str(obj) + return obj diff --git a/cosmos_transfer1/utils/easy_io/handlers/yaml_handler.py b/cosmos_transfer1/utils/easy_io/handlers/yaml_handler.py new file mode 100644 index 00000000..246c123a --- /dev/null +++ b/cosmos_transfer1/utils/easy_io/handlers/yaml_handler.py @@ -0,0 +1,38 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import yaml + +try: + from yaml import CDumper as Dumper # type: ignore + from yaml import CLoader as Loader # type: ignore +except ImportError: + from yaml import Loader, Dumper # type: ignore + +from cosmos_transfer1.utils.easy_io.handlers.base import BaseFileHandler # isort:skip + + +class YamlHandler(BaseFileHandler): + def load_from_fileobj(self, file, **kwargs): + kwargs.setdefault("Loader", Loader) + return yaml.load(file, **kwargs) + + def dump_to_fileobj(self, obj, file, **kwargs): + kwargs.setdefault("Dumper", Dumper) + yaml.dump(obj, file, **kwargs) + + def dump_to_str(self, obj, **kwargs): + kwargs.setdefault("Dumper", Dumper) + return yaml.dump(obj, **kwargs) diff --git a/cosmos_transfer1/utils/ema.py b/cosmos_transfer1/utils/ema.py new file mode 100644 index 00000000..e402f65f --- /dev/null +++ b/cosmos_transfer1/utils/ema.py @@ -0,0 +1,327 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from contextlib import contextmanager +from typing import TYPE_CHECKING, Any, Generator, List, Optional, Union + +import numpy as np +import torch +from megatron.core import parallel_state + +from cosmos_transfer1.utils import distributed, log + +if TYPE_CHECKING: + from cosmos_transfer1.utils.model import Model + + +class FastEmaModelUpdater: + """ + This class is used to update target model~(EMA) given source model~(regular model) and beta. + The method interaface mimic :class:`EMAModelTracker` and :class:`PowerEMATracker`. + Different from two classes, this class does not maintain the EMA model weights as buffers. It expects the user to have two module with same architecture and weights shape. + The class is proposed to work with FSDP model where above two classes are not working as expected. Besides, it is strange to claim model weights as buffers and do unnecessary name changing in :class:`EMAModelTracker` and :class:`PowerEMATracker`. Moeving forward, we should use this class instead of above two classes. + """ + + def __init__(self): + # Flag to indicate whether the cache is taken or not. Useful to avoid cache overwrite + self.is_cached = False + + def update_average(self, src_model: torch.nn.Module, tgt_model: torch.nn.Module, beta: float = 0.9999) -> None: + target_list = [] + source_list = [] + for tgt_params, src_params in zip(tgt_model.parameters(), src_model.parameters()): + assert ( + tgt_params.dtype == torch.float32 + ), f"EMA model only works in FP32 dtype, got {tgt_params.dtype} instead." + target_list.append(tgt_params) + source_list.append(src_params.data) + torch._foreach_mul_(target_list, beta) + torch._foreach_add_(target_list, source_list, alpha=1.0 - beta) + + def copy_to(self, src_model: torch.nn.Module, tgt_model: torch.nn.Module) -> None: + for tgt_params, src_params in zip(tgt_model.parameters(), src_model.parameters()): + tgt_params.data.copy_(src_params.data) + + def cache(self, parameters: Any, is_cpu: bool = False) -> None: + """Save the current parameters for restoring later. + + Args: + parameters (iterable): Iterable of torch.nn.Parameter to be temporarily stored. + """ + assert self.is_cached is False, "EMA cache is already taken. Did you forget to restore it?" + device = "cpu" if is_cpu else "cuda" + self.collected_params = [param.clone().to(device) for param in parameters] + self.is_cached = True + + def restore(self, parameters: Any) -> None: + """Restore the parameters in self.collected_params. + + Useful to validate the model with EMA parameters without affecting the + original optimization process. Store the parameters before copy_to(). + After validation (or model saving), use this to restore the former parameters. + + Args: + parameters (iterable): Iterable of torch.nn.Parameter to be updated with the stored parameters. + """ + assert self.is_cached, "EMA cache is not taken yet." + for c_param, param in zip(self.collected_params, parameters, strict=False): + param.data.copy_(c_param.data.type_as(param.data)) + self.collected_params = [] + # Release the cache after we call restore + self.is_cached = False + + +def get_buffer_name(param_name: str, torch_compile_buffer_renaming: bool = False) -> str: + """ + This function creates buffer name used by EMA from parameter's name + + Args: + param_name (str): Model's parameter name + Returns: + buffer_name (str): buffer name to be used for given parameter name + """ + + buffer_name = param_name.replace(".", "-") + + if torch_compile_buffer_renaming: + # torch.compile() adds _orig_mod to state dict names, this way we get original name + buffer_name = buffer_name.replace("_orig_mod-", "") + + return buffer_name + + +class EMAModelTracker(torch.nn.Module): + """This is a class to track the EMA model weights. + + The EMA weights are registered as buffers, which are extractable as state dicts. The names follow those of the + regular weights, except all "." are replaced with "-" (limitation of register_buffer()). This is similar to SDXL's + implementation of EMA. There are no optimizable parameters. + + Attributes: + collected_params (list): temporarily stores the regular weights while in EMA mode. + beta (float): EMA decay rate. (default: 0.9999). + torch_compile_buffer_renaming (bool): whether to remove '_orig_mod-' from buffer names when torch.compile is used + """ + + def __init__(self, model: Model, beta: float = 0.9999, torch_compile_buffer_renaming: bool = False): + """Constructor of the EMA model weight tracker. + + Args: + model (Model): The PyTorch model. + beta (float): EMA decay rate. (default: 0.9999). + """ + super().__init__() + self.torch_compile_buffer_renaming: bool = torch_compile_buffer_renaming + if not 0.0 <= beta <= 1.0: + raise ValueError("Decay must be between 0 and 1") + self.beta = beta + for name, param in model.named_parameters(): + if param.requires_grad: + buffer_name = get_buffer_name(name, self.torch_compile_buffer_renaming) + self.register_buffer(buffer_name, param.clone().detach().data) + self.collected_params = [] + # Flag to indicate whether the cache is taken or not. Useful to avoid cache overwrite + self.is_cached = False + + @torch.no_grad() + def update_average(self, model: Model, iteration: Optional[int] = None) -> None: + del iteration + target_list = [] + source_list = [] + ema_buffers = self.state_dict() + for name, param in model.named_parameters(): + if param.requires_grad: + buffer_name = get_buffer_name(name, self.torch_compile_buffer_renaming) + buffer = ema_buffers[buffer_name] + assert buffer.dtype == torch.float32, f"EMA model only works in FP32 dtype, got {buffer.dtype} instead." + target_list.append(buffer) + source_list.append(param.data) + torch._foreach_mul_(target_list, self.beta) + torch._foreach_add_(target_list, source_list, alpha=1.0 - self.beta) + + def copy_to(self, model: Model) -> None: + ema_buffers = self.state_dict() + for name, param in model.named_parameters(): + if param.requires_grad: + buffer_name = get_buffer_name(name, self.torch_compile_buffer_renaming) + buffer = ema_buffers[buffer_name] + param.data.copy_(buffer.data) + + def cache(self, parameters: Any, is_cpu: bool = False) -> None: + """Save the current parameters for restoring later. + + Args: + parameters (iterable): Iterable of torch.nn.Parameter to be temporarily stored. + """ + assert self.is_cached is False, "EMA cache is already taken. Did you forget to restore it?" + device = "cpu" if is_cpu else "cuda" + self.collected_params = [param.clone().to(device) for param in parameters] + self.is_cached = True + + def restore(self, parameters: Any) -> None: + """Restore the parameters in self.collected_params. + + Useful to validate the model with EMA parameters without affecting the + original optimization process. Store the parameters before copy_to(). + After validation (or model saving), use this to restore the former parameters. + + Args: + parameters (iterable): Iterable of torch.nn.Parameter to be updated with the stored parameters. + """ + assert self.is_cached, "EMA cache is not taken yet." + for c_param, param in zip(self.collected_params, parameters, strict=False): + param.data.copy_(c_param.data.type_as(param.data)) + self.collected_params = [] + # Release the cache after we call restore + self.is_cached = False + + @classmethod + def initialize_multi_rank_ema( + cls, model: torch.nn.Module, rate: Union[float, List[float]], num: int = 1, enabled: bool = True + ) -> Optional[EMAModelTracker]: + """ + Class method to initialize per rank EMA Model Tracker with different rate. + Each rank will have a different rate based on the given configuration, resulting in different EMA weights. + + Args: + model (torch.nn.Module): The neural network model to be tracked. + rate (Union[float, List[float]]): The decay rate(s) for the EMA. If a list is provided, + it corresponds to rates for different ranks. + num (int, optional): The number of leading ranks to consider for different rates. + Defaults to 1. + enabled (bool, optional): Flag to enable or disable the creation of the tracker. + If False, returns None. Defaults to True. + + Returns: + Optional[EMAModelTracker]: An instance of EMAModelTracker if enabled, otherwise None. + + Example: + >>> model = torch.nn.Linear(10, 2) + >>> tracker = EMAModelTracker.initialize_ema_from_settings(model, rate=[0.1, 0.2], num=2) + >>> print(tracker) + + Notes: + If `rate` is a list and the current rank is less than `num`, the rate for the current rank + is used. If the current rank exceeds `num`, the first rate in the list is used by default. + """ + if not enabled: + return None + if parallel_state.is_initialized(): + cur_dp_rank = parallel_state.get_data_parallel_rank(with_context_parallel=True) + log.critical(f"using MCore parallel_state for EMA initialization. DP RANK: {cur_dp_rank}", rank0_only=False) + log.warning("It should not used together with FSDP!") + else: + cur_dp_rank = distributed.get_rank() + log.critical(f"using torch.distributed for EMA initialization. DP RANK: {cur_dp_rank}", rank0_only=False) + rate = rate if isinstance(rate, list) else [rate] + num = min(num, len(rate)) + rate = rate[cur_dp_rank] if cur_dp_rank < num else rate[0] + if cur_dp_rank < num: + print(f"EMAModelTracker: rank {cur_dp_rank}, rate {rate}") + return cls(model, rate) + + +class PowerEMATracker(EMAModelTracker): + def __init__(self, model: Model, s: float = 0.1, torch_compile_buffer_renaming: bool = False): + """Constructor of the EMA model weight tracker. + + Args: + model (Model): The PyTorch model. + s (float): EMA decay rate. See EDM2 paper + torch_compile_buffer_renaming (bool): whether to remove '_orig_mod-' from buffer names when torch.compile is used + """ + super().__init__(model=model, beta=0.0, torch_compile_buffer_renaming=torch_compile_buffer_renaming) + self.exp = np.roots([1, 7, 16 - s**-2, 12 - s**-2]).real.max() + + @torch.no_grad() + def update_average(self, model: Model, iteration: Optional[int] = None) -> None: + if iteration == 0: + beta = 0.0 + else: + i = iteration + 1 + beta = (1 - 1 / i) ** (self.exp + 1) + self.beta = beta + + super().update_average(model, iteration) + + @classmethod + def initialize_multi_rank_ema( + cls, model: torch.nn.Module, rate: float, num: int, enabled: bool = True + ) -> Optional[PowerEMATracker]: + """ + Class method to initialize per rank EMA Model Tracker with different rate. + Each rank will have a different rate based on the given configuration, resulting in different EMA weights. + + Args: + model (torch.nn.Module): The neural network model for which the EMA tracker is being set up. + num (int): The number of ranks for which the rate adjustment is applied. Beyond this, the rate remains unchanged. + rate (float): The base decay rate for the EMA calculation. + enabled (bool, optional): Flag to enable or disable the initialization of the tracker. If False, returns None. + Defaults to True. + + Returns: + Optional[PowerEMATracker]: An instance of PowerEMATracker with adjusted rate if enabled, otherwise None. + + Raises: + None + + Example: + >>> model = torch.nn.Linear(10, 2) + >>> tracker = PowerEMATracker.initialize_multi_rank_ema(model, num=3, rate=0.99) + >>> print(tracker) + + Notes: + The decay rate is modified by dividing it by 2 raised to the power of the rank for each rank less than `num`. + If the rank is greater than or equal to `num`, the base rate is used without modification. This approach + allows higher ranked processes to have a less aggressive decay, potentially reflecting their delayed synchronization + in a distributed training scenario. + """ + if not enabled: + return None + if parallel_state.is_initialized(): + cur_dp_rank = parallel_state.get_data_parallel_rank(with_context_parallel=True) + log.critical(f"using MCore parallel_state for EMA initialization. DP RANK: {cur_dp_rank}", rank0_only=False) + log.warning("It should not used together with FSDP!") + else: + cur_dp_rank = distributed.get_rank() + log.critical(f"using torch.distributed for EMA initialization. DP RANK: {cur_dp_rank}", rank0_only=False) + + divider = 2**cur_dp_rank if cur_dp_rank < num else 1 + if cur_dp_rank < num: + print(f"PowerEMATracker: rank {cur_dp_rank}, rate {rate / divider}") + return cls(model, rate / divider) + + +@contextmanager +def ema_scope(model: Model, enabled: bool = False) -> Generator[None, None, None]: + """Context manager for switching between regular and EMA model weights. + + Args: + model (Model): The PyTorch model. + enabled (bool): Whether switching to EMA weights is enabled (default: False). + """ + if enabled: + assert hasattr(model, "ema") and isinstance(model.ema, (FastEmaModelUpdater, EMAModelTracker, PowerEMATracker)) + model.ema.cache(model.parameters()) + model.ema.copy_to(model) + log.info("EMA: switched to EMA weights.") + try: + yield None + finally: + if enabled: + model.ema.restore(model.parameters()) + log.info("EMA: restored regular weights.") diff --git a/cosmos_transfer1/utils/fused_adam.py b/cosmos_transfer1/utils/fused_adam.py new file mode 100644 index 00000000..76268d93 --- /dev/null +++ b/cosmos_transfer1/utils/fused_adam.py @@ -0,0 +1,398 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +from apex.multi_tensor_apply import multi_tensor_applier + +from cosmos_transfer1.utils import distributed, log + + +class FusedAdam(torch.optim.Optimizer): + """Implements Adam algorithm. + + Currently GPU-only. Requires Apex to be installed via + ``pip install -v --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" ./``. + + This version of fused Adam implements 2 fusions. + + * Fusion of the Adam update's elementwise operations + * A multi-tensor apply launch that batches the elementwise updates applied to all the model's parameters + into one or a few kernel launches. + + :class:`apex.optimizers.FusedAdam` may be used as a drop-in replacement for ``torch.optim.AdamW``, + or ``torch.optim.Adam`` with ``adam_w_mode=False``:: + + opt = apex.optimizers.FusedAdam(model.parameters(), lr = ....) + ... + opt.step() + + :class:`apex.optimizers.FusedAdam` may be used with or without Amp. If you wish to use :class:`FusedAdam` with Amp, + you may choose any ``opt_level``:: + + opt = apex.optimizers.FusedAdam(model.parameters(), lr = ....) + model, opt = amp.initialize(model, opt, opt_level="O0" or "O1 or "O2") + ... + opt.step() + + In general, ``opt_level="O1"`` is recommended. + + + .. warning:: + A previous version of :class:`FusedAdam` allowed a number of additional arguments to ``step``. + These additional arguments are now deprecated and unnecessary. + + Adam was been proposed in `Adam: A Method for Stochastic Optimization`_. + + Arguments: + params (iterable): iterable of parameters to optimize or dicts defining + parameter groups. + lr (float, optional): learning rate. (default: 1e-3) + betas (Tuple[float, float], optional): coefficients used for computing + running averages of gradient and its square. (default: (0.9, 0.999)) + eps (float, optional): term added to the denominator to improve + numerical stability. (default: 1e-8) + weight_decay (float, optional): weight decay (L2 penalty) (default: 0) + amsgrad (boolean, optional): whether to use the AMSGrad variant of this + algorithm from the paper `On the Convergence of Adam and Beyond`_ + (default: False) NOT SUPPORTED in FusedAdam! + adam_w_mode (boolean, optional): Apply L2 regularization or weight decay + True for decoupled weight decay(also known as AdamW) (default: True) + capturable (bool, optional): whether to use the version of the optimizer + that can be used with CUDA Graphs. (default: False) + master_weights (bool, optional): whether to maintain FP32 master weights + in the optimizer with FP16 mixed precision training, currently can + only be used with capturable set to True. (default: False) + + .. _Adam - A Method for Stochastic Optimization: + https://arxiv.org/abs/1412.6980 + .. _On the Convergence of Adam and Beyond: + https://openreview.net/forum?id=ryQu7f-RZ + """ + + def __init__( + self, + params, + lr=1e-3, + bias_correction=True, + betas=(0.9, 0.999), + eps=1e-8, + adam_w_mode=True, + weight_decay=0.0, + amsgrad=False, + capturable=False, + master_weights=False, + ): + if amsgrad: + raise RuntimeError("FusedAdam does not support the AMSGrad variant.") + if master_weights and not capturable: + raise RuntimeError("Master weights is currently only supported with the capturable version.") + # If the optimizer is capturable then LR should be a tensor (on GPU) + log.warning(f"FusedAdam master_weights: {master_weights} capturable: {capturable}") + lr = torch.tensor(lr, dtype=torch.float32) if capturable else lr + defaults = dict(lr=lr, bias_correction=bias_correction, betas=betas, eps=eps, weight_decay=weight_decay) + super(FusedAdam, self).__init__(params, defaults) + self.adam_w_mode = 1 if adam_w_mode else 0 + + self.capturable = capturable + self.master_weights = master_weights + + self.param_groups_master = None + + if capturable: + for idx, group in enumerate(self.param_groups): + if len(group["params"]) == 0: + continue + device = group["params"][0].device + for item in ["lr"]: + if isinstance(group[item], float): + group[item] = torch.tensor(group[item], dtype=torch.float32) + self.param_groups[idx][item] = group[item].to(device=device) + + self._step_supports_amp_scaling = True + + if multi_tensor_applier.available: + import amp_C + + # Skip buffer + self._dummy_overflow_buf = torch.tensor([0], dtype=torch.int, device="cuda") + self.multi_tensor_adam = amp_C.multi_tensor_adam + self.multi_tensor_adam_capturable = amp_C.multi_tensor_adam_capturable + self.multi_tensor_adam_capturable_master = amp_C.multi_tensor_adam_capturable_master + else: + raise RuntimeError("apex.optimizers.FusedAdam requires cuda extensions") + + def step(self, closure=None, grads=None, output_params=None, scale=None, grad_norms=None, grad_scaler=None): + """Performs a single optimization step. + + Arguments: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. + + The remaining arguments are deprecated, and are only retained (for the moment) for error-checking purposes. + """ + if any(p is not None for p in [grads, output_params, scale, grad_norms]): + raise RuntimeError( + "FusedAdam has been updated. " + "Simply initialize it identically to torch.optim.Adam, and call step() with no arguments." + ) + loss = None + if closure is not None: + loss = closure() + + if self.param_groups_master is None: + # Create full precision master weights + self.param_groups_master = [] + for i, pg in enumerate(self.param_groups): + param_list = pg["params"] + self.param_groups_master.append( + { + "params": [p.clone().detach().float() if self.master_weights else None for p in param_list], + } + ) + + for group, group_master in zip(self.param_groups, self.param_groups_master): + if len(group["params"]) == 0: + continue + device = group["params"][0].device + bias_correction = 1 if "bias_correction" in group and group["bias_correction"] else 0 + beta1, beta2 = group["betas"] + + # assume same step across group now to simplify things + # per parameter step can be easily support by making it tensor, or pass list into kernel + if "step" in group: + if self.capturable: + group["step"] = ( + group["step"].to(device=device) + if isinstance(group["step"], torch.Tensor) + else torch.tensor(group["step"], dtype=torch.int32, device=device) + ) + group["step"] += (self._dummy_overflow_buf != 1).to(torch.int) + else: + group["step"] += 1 + else: + group["step"] = 1 if not self.capturable else torch.tensor([1], dtype=torch.int, device=device) + + if self.capturable: + group["lr"] = ( + group["lr"].to(device=device) + if isinstance(group["lr"], torch.Tensor) + else torch.tensor(group["lr"], dtype=torch.float32, device=device) + ) + + # create lists for multi-tensor apply + g_16, p_16, m_16, v_16 = [], [], [], [] + g_bf, p_bf, m_bf, v_bf = [], [], [], [] + g_32, p_32, m_32, v_32 = [], [], [], [] + p_16_master = [] + p_32_master = [] + bf16_master = [] + + for p, p_master in zip(group["params"], group_master["params"]): + if p.grad is None: + continue + if p.grad.data.is_sparse: + raise RuntimeError( + "FusedAdam does not support sparse gradients, please consider SparseAdam instead" + ) + + state = self.state[p] + # State initialization + if len(state) == 0: + # Exponential moving average of gradient values + state["exp_avg"] = torch.zeros_like(p.data).float() + # Exponential moving average of squared gradient values + state["exp_avg_sq"] = torch.zeros_like(p.data).float() + + if p.dtype == torch.float16: + if self.master_weights: + p_16_master.append(p_master.data) + g_16.append(p.grad.data) + p_16.append(p.data) + m_16.append(state["exp_avg"]) + v_16.append(state["exp_avg_sq"]) + elif p.dtype == torch.bfloat16: + if self.master_weights: + bf16_master.append(p_master.data) + g_bf.append(p.grad) + p_bf.append(p) + m_bf.append(state["exp_avg"]) + v_bf.append(state["exp_avg_sq"]) + elif p.dtype == torch.float32: + if self.master_weights: + p_32_master.append(p_master.data) + g_32.append(p.grad.data) + p_32.append(p.data) + m_32.append(state["exp_avg"]) + v_32.append(state["exp_avg_sq"]) + else: + raise RuntimeError("FusedAdam only support fp16 and fp32.") + + # If the optimizer is capturable, then if there's a grad scaler it works + # on the GPU + a different multi_tensor_applier should be called + if self.capturable: + # overflow check of gradients + found_inf = ( + grad_scaler._check_inf_per_device(self)[device] + if grad_scaler is not None + else torch.zeros((1,), device=device) + ) + self._dummy_overflow_buf.copy_(found_inf) + + # get unscale scale factor + scale, inv_scale = None, None + if grad_scaler: + scale = grad_scaler._get_scale_async() + inv_scale = scale.double().reciprocal().float() + else: + scale = torch.ones((1,), device=device, dtype=torch.float32) + inv_scale = torch.ones((1,), device=device, dtype=torch.float32) + + if len(g_16) > 0: + multi_tensor_applier( + ( + self.multi_tensor_adam_capturable_master + if self.master_weights + else self.multi_tensor_adam_capturable + ), + self._dummy_overflow_buf, + [g_16, p_16, m_16, v_16, p_16_master] if self.master_weights else [g_16, p_16, m_16, v_16], + group["lr"], + beta1, + beta2, + group["eps"], + group["step"], + self.adam_w_mode, + bias_correction, + group["weight_decay"], + inv_scale, + ) + + if len(g_bf) > 0: + multi_tensor_applier( + ( + self.multi_tensor_adam_capturable_master + if self.master_weights + else self.multi_tensor_adam_capturable + ), + self._dummy_overflow_buf, + [g_bf, p_bf, m_bf, v_bf, bf16_master] if self.master_weights else [g_bf, p_bf, m_bf, v_bf], + group["lr"], + beta1, + beta2, + group["eps"], + group["step"], + self.adam_w_mode, + bias_correction, + group["weight_decay"], + inv_scale, + ) + + if len(g_32) > 0: + multi_tensor_applier( + ( + self.multi_tensor_adam_capturable_master + if self.master_weights + else self.multi_tensor_adam_capturable + ), + self._dummy_overflow_buf, + [g_32, p_32, m_32, v_32, p_32_master] if self.master_weights else [g_32, p_32, m_32, v_32], + group["lr"], + beta1, + beta2, + group["eps"], + group["step"], + self.adam_w_mode, + bias_correction, + group["weight_decay"], + inv_scale, + ) + else: + if len(g_16) > 0: + multi_tensor_applier( + self.multi_tensor_adam, + self._dummy_overflow_buf, + [g_16, p_16, m_16, v_16], + group["lr"], + beta1, + beta2, + group["eps"], + group["step"], + self.adam_w_mode, + bias_correction, + group["weight_decay"], + ) + + if len(g_bf) > 0: + multi_tensor_applier( + self.multi_tensor_adam, + self._dummy_overflow_buf, + [g_bf, p_bf, m_bf, v_bf], + group["lr"], + beta1, + beta2, + group["eps"], + group["step"], + self.adam_w_mode, + bias_correction, + group["weight_decay"], + ) + + if len(g_32) > 0: + multi_tensor_applier( + self.multi_tensor_adam, + self._dummy_overflow_buf, + [g_32, p_32, m_32, v_32], + group["lr"], + beta1, + beta2, + group["eps"], + group["step"], + self.adam_w_mode, + bias_correction, + group["weight_decay"], + ) + + return loss + + def load_state_dict(self, state_dict): + super().load_state_dict(state_dict) + for group in self.param_groups: + if self.capturable: + group["lr"] = ( + group["lr"].cuda() + if isinstance(group["lr"], torch.Tensor) + else torch.tensor(group["lr"], dtype=torch.float32).cuda() + ) + + if "step" in group: + if self.capturable: + if distributed.get_rank() == 0: + step = ( + group["step"].cuda() + if isinstance(group["step"], torch.Tensor) + else torch.tensor([group["step"]], dtype=torch.int32).cuda() + ) + else: + step = torch.zeros(1, dtype=torch.int32).cuda() + # make it compatible with FSDP optimizer + distributed.broadcast(step, 0) + group["step"] = step + elif isinstance(group["step"], torch.Tensor): + group["step"] = group["step"].item() + for p in group["params"]: + state = self.state[p] + if "exp_avg" in state: + state["exp_avg"] = state["exp_avg"].float() + state["exp_avg_sq"] = state["exp_avg_sq"].float() diff --git a/cosmos_transfer1/utils/lazy_config/lazy.py b/cosmos_transfer1/utils/lazy_config/lazy.py index 1dd72bb0..6db66f42 100644 --- a/cosmos_transfer1/utils/lazy_config/lazy.py +++ b/cosmos_transfer1/utils/lazy_config/lazy.py @@ -18,10 +18,13 @@ import collections.abc as abc import importlib import inspect +import logging import os +import pickle import uuid from collections import OrderedDict from contextlib import contextmanager +from copy import deepcopy from dataclasses import is_dataclass from typing import Any, Dict, List, Tuple, Union @@ -32,6 +35,15 @@ from cosmos_transfer1.utils.lazy_config.file_io import PathManager from cosmos_transfer1.utils.lazy_config.registry import _convert_target_to_string +try: + import dill as dill_pickle +except ImportError: + dill_pickle = None +try: + import cloudpickle +except ImportError: + cloudpickle = None + __all__ = ["LazyCall", "LazyConfig"] @@ -221,6 +233,22 @@ class LazyConfig: which may contain definition of lazily-constructed objects. """ + @staticmethod + def load_rel(filename: str, keys: Union[None, str, Tuple[str, ...]] = None): + """ + Similar to :meth:`load()`, but load path relative to the caller's + source file. + + This has the same functionality as a relative import, except that this method + accepts filename as a string, so more characters are allowed in the filename. + """ + caller_frame = inspect.stack()[1] + caller_fname = caller_frame[0].f_code.co_filename + assert caller_fname != "", "load_rel Unable to find caller" + caller_dir = os.path.dirname(caller_fname) + filename = os.path.join(caller_dir, filename) + return LazyConfig.load(filename, keys) + @staticmethod def load(filename: str, keys: Union[None, str, Tuple[str, ...]] = None): """ @@ -274,3 +302,129 @@ def load(filename: str, keys: Union[None, str, Tuple[str, ...]] = None): flags={"allow_objects": True}, ) return ret + + @staticmethod + def save_pkl(cfg, filename: str) -> str: + """ + Saves a Config object to a file using pickle serialization. This method is typically used + when the configuration object contains complex objects, such as lambdas, that are not supported by + simpler serialization methods like YAML. The function attempts to create a deep copy of the configuration + object before serialization to ensure that the original object remains unmodified. + + Args: + cfg: A Config object to be serialized and saved. + filename: The path and name of the file where the configuration should be saved. The function + assumes the file extension indicates a pickle format (e.g., .pkl). + + Returns: + str: The filename to which the configuration was saved. This can be used to verify the file location + or log the outcome. + + Notes: + - The function logs a warning if the configuration is successfully saved using pickle. + - If saving fails, an error is logged with the exception details. + """ + logger = logging.getLogger(__name__) + try: + cfg = deepcopy(cfg) + except Exception: + pass + + try: + with PathManager.open(filename, "wb") as f: + pickle.dump(cfg, f) + logger.warning(f"Config is saved using pickle at {filename}.") + except Exception as e: + logger.error(f"Failed to save config to {filename}: {e}. Trying dill or cloudpickle instead") + if dill_pickle: + try: + with PathManager.open(filename, "wb") as f: + pickle.dump(dill_pickle.dumps(cfg, recurse=True), f) + logger.warning(f"Config is saved using dill at {filename}.") + except Exception as e: + logger.error(f"Failed to save config to {filename}: {e}.") + if cloudpickle: + try: + with PathManager.open(filename, "wb") as f: + pickle.dump(cloudpickle.dumps(cfg), f) + logger.warning(f"Config is saved using cloudpickle at {filename}.") + except Exception as e: + logger.error(f"Failed to save config to {filename}: {e}.") + else: + logger.error("cloudpickle is not available. Cannot save the config.") + raise e + + return filename + + @staticmethod + def save_yaml(cfg, filename: str) -> str: + """ + Saves a Config object to a file using YAML serialization. This method is beneficial when the configuration object's content needs to be human-readable and easily editable. YAML is suitable for configurations that do not contain complex types like lambdas, which must be handled differently. The function converts unserializable items to strings before saving to ensure compatibility with YAML serialization. + + Args: + cfg: A Config object to be serialized and saved. It handles both DictConfig and ListConfig types. + filename: The path and name of the file where the configuration should be saved. The function does not require a specific file extension but typically uses '.yaml'. + + Returns: + str: The filename to which the configuration was saved. This can be used to verify the file location or log the outcome. + + Notes: + - The function logs a warning if the configuration is successfully saved using YAML. + - If saving fails, an error is logged with the exception details. + """ + logger = logging.getLogger(__name__) + try: + cfg = deepcopy(cfg) + except Exception: + pass + + # Define a function to check if an item is serializable to YAML + def is_serializable(item): + try: + OmegaConf.to_yaml(item) + return True + except Exception as e: + return False + + # Function to convert unserializable items to strings + def serialize_config(config): + if isinstance(config, DictConfig): + for key, value in config.items(): + if isinstance(value, (DictConfig, ListConfig)): + try: + if "_target_" in value: + default_params = get_default_params(value["_target_"]) + for default_key, default_v in default_params.items(): + if default_key not in value: + value[default_key] = default_v + except Exception as e: + logger.error(f"Failed to add default argument values: {e}") + + serialize_config(value) + else: + if not is_serializable(value) and value is not None: + config[key] = str(value) + elif isinstance(config, ListConfig): + for i, item in enumerate(config): + if isinstance(item, (DictConfig, ListConfig)): + serialize_config(item) + else: + if not is_serializable(item) and item is not None: + config[i] = str(item) + else: + raise NotImplementedError("Input config must be a DictConfig or ListConfig.") + return config + + # Convert Config object to a DictConfig object. + config_dict = attrs.asdict(cfg) + config_omegaconf = DictConfig(content=config_dict, flags={"allow_objects": True}) + + # Serialize the DictConfig object by converting non-serializable objects to strings. + config_omegaconf = serialize_config(config_omegaconf) + + config_dict: Dict[str, Any] = OmegaConf.to_container(config_omegaconf, resolve=True) + sorted_config: OrderedDict[str, Any] = sort_recursive(config_dict) + with open(filename, "w") as f: + yaml.dump(sorted_config, f, default_flow_style=False) + logger.warning(f"Config is saved using omegaconf at {filename}.") + return filename diff --git a/cosmos_transfer1/utils/misc.py b/cosmos_transfer1/utils/misc.py index 83010da1..39bd30ea 100644 --- a/cosmos_transfer1/utils/misc.py +++ b/cosmos_transfer1/utils/misc.py @@ -32,6 +32,9 @@ import termcolor import torch +from torch.distributed._functional_collectives import AsyncCollectiveTensor +from torch.distributed._tensor.api import DTensor + from cosmos_transfer1.utils import distributed, log @@ -100,6 +103,18 @@ def to( return data +def get_local_tensor_if_DTensor(tensor: torch.Tensor | DTensor) -> torch.tensor: + if isinstance(tensor, DTensor): + local = tensor.to_local() + # As per PyTorch documentation, if the communication is not finished yet, we need to wait for it to finish + # https://pytorch.org/docs/stable/distributed.tensor.html#torch.distributed.tensor.DTensor.to_local + if isinstance(local, AsyncCollectiveTensor): + return local.wait() + else: + return local + return tensor + + def serialize(data: Any) -> Any: """Serialize data by hierarchically traversing through iterables. diff --git a/cosmos_transfer1/utils/model.py b/cosmos_transfer1/utils/model.py new file mode 100644 index 00000000..54d06ce6 --- /dev/null +++ b/cosmos_transfer1/utils/model.py @@ -0,0 +1,136 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any + +import torch + +from cosmos_transfer1.utils.lazy_config import LazyDict, instantiate + + +class Model(torch.nn.Module): + """The base model class. It is inherited from torch.nn.Module. + + All models should inherit Model. It should include the implementions for all the + computation graphs. All inheriting child classes should implement the following methods: + - training_step(): The training step of the model, including the loss computation. + - validation_step(): The validation step of the model, including the loss computation. + - forward(): The computation graph for model inference. + The following methods have default implementations in Model: + - init_optimizer_scheduler(): Creates the optimizer and scheduler for the model. + """ + + def __init__(self) -> None: + super().__init__() + self.on_model_init_start(set_barrier=False) + + def init_optimizer_scheduler( + self, optimizer_config: LazyDict, scheduler_config: LazyDict + ) -> tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LRScheduler]: + """Creates the optimizer and scheduler for the model. + + Args: + config_model (ModelConfig): The config object for the model. + + Returns: + optimizer (torch.optim.Optimizer): The model optimizer. + scheduler (torch.optim.lr_scheduler.LRScheduler): The optimization scheduler. + """ + optimizer_config.params = self.parameters() + optimizer = instantiate(optimizer_config) + scheduler_config.optimizer = optimizer + scheduler = instantiate(scheduler_config) + return optimizer, scheduler + + def training_step( + self, data_batch: dict[str, torch.Tensor], iteration: int + ) -> tuple[dict[str, torch.Tensor], torch.Tensor]: + """The training step of the model, including the loss computation. + + Args: + data (dict[str, torch.Tensor]): Data batch (dictionary of tensors). + iteration (int): Current iteration number. + + Returns: + output_batch (dict[str, torch.Tensor]): Auxiliary model output from the training batch. + loss (torch.Tensor): The total loss for backprop (weighted sum of various losses). + """ + raise NotImplementedError + + @torch.no_grad() + def validation_step( + self, data_batch: dict[str, torch.Tensor], iteration: int + ) -> tuple[dict[str, torch.Tensor], torch.Tensor]: + """The validation step of the model, including the loss computation. + + Args: + data (dict[str, torch.Tensor]): Data batch (dictionary of tensors). + iteration (int): Current iteration number. + + Returns: + output_batch (dict[str, torch.Tensor]): Auxiliary model output from the validation batch. + loss (torch.Tensor): The total loss (weighted sum of various losses). + """ + raise NotImplementedError + + @torch.inference_mode() + def forward(self, *args: Any, **kwargs: Any) -> Any: + """The computation graph for model inference. + + Args: + *args: Whatever you decide to pass into the forward method. + **kwargs: Keyword arguments are also possible. + + Return: + Your model's output. + """ + raise NotImplementedError + + def on_model_init_start(self, set_barrier=False) -> None: + return + + def on_model_init_end(self, set_barrier=False) -> None: + return + + def on_train_start(self, memory_format: torch.memory_format = torch.preserve_format) -> None: + """The model preparation before the training is launched + + Args: + memory_format (torch.memory_format): Memory format of the model. + """ + pass + + def on_before_zero_grad( + self, optimizer: torch.optim.Optimizer, scheduler: torch.optim.lr_scheduler.LRScheduler, iteration: int + ) -> None: + """Hook before zero_grad() is called. + + Args: + optimizer (torch.optim.Optimizer): The model optimizer. + scheduler (torch.optim.lr_scheduler.LRScheduler): The optimization scheduler. + iteration (int): Current iteration number. + """ + pass + + def on_after_backward(self, iteration: int = 0) -> None: + """Hook after loss.backward() is called. + + This method is called immediately after the backward pass, allowing for custom operations + or modifications to be performed on the gradients before the optimizer step. + + Args: + iteration (int): Current iteration number. + """ + pass diff --git a/cosmos_transfer1/utils/parallel_state_helper.py b/cosmos_transfer1/utils/parallel_state_helper.py new file mode 100644 index 00000000..f531ab00 --- /dev/null +++ b/cosmos_transfer1/utils/parallel_state_helper.py @@ -0,0 +1,24 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from megatron.core import parallel_state + + +def is_tp_cp_pp_rank0(): + return ( + parallel_state.get_tensor_model_parallel_rank() == 0 + and parallel_state.get_pipeline_model_parallel_rank() == 0 + and parallel_state.get_context_parallel_rank() == 0 + ) diff --git a/cosmos_transfer1/utils/trainer.py b/cosmos_transfer1/utils/trainer.py new file mode 100644 index 00000000..cdb6af1f --- /dev/null +++ b/cosmos_transfer1/utils/trainer.py @@ -0,0 +1,279 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import functools +import os +import signal + +import torch +import torch.distributed as dist +import torch.utils.data +from megatron.core import parallel_state + +from cosmos_transfer1.utils import callback, distributed, ema, log, misc +from cosmos_transfer1.utils.checkpointer import Checkpointer +from cosmos_transfer1.utils.lazy_config import LazyConfig, instantiate +from cosmos_transfer1.utils.model import Model + + +class Trainer: + """The base trainer class. + + All trainers should inherit Trainer. It contains the basic functionality for model training + (particularly suited for large-scale training), including data parallel (DDP/FSDP), model weight average (EMA), + mixed-precision training (fp16/bf16). + + Attributes: + checkpointer (Checkpointer): checkpointer object to save/load model weights and optimizer states. + training_timer (misc.Timer): Timer object to time code blocks and functions. + """ + + def __init__(self, config): + """Constructor of the trainer. + + Args: + config (Config): The config object for the codebase. + """ + super().__init__() + self.config = config + # Set up the distributed computing environment. + with misc.timer("init_distributed"): + distributed.init() + # Set up parallel states. + if hasattr(config.model, "context_parallel_size"): + if config.model_parallel.context_parallel_size > 1: + raise ValueError( + "Both config.model.context_parallel_size and config.model_parallel.context_parallel_size are set. " + "config.model.context_parallel_size is deprecated. Please only set config.model_parallel.context_parallel_size." + ) + else: + log.critical( + "Using deprecated config.model.context_parallel_size. Please use config.model_parallel.context_parallel_size instead." + ) + config.model_parallel.context_parallel_size = config.model.context_parallel_size + parallel_state.initialize_model_parallel( + pipeline_model_parallel_size=config.model_parallel.pipeline_model_parallel_size, + tensor_model_parallel_size=config.model_parallel.tensor_model_parallel_size, + context_parallel_size=config.model_parallel.context_parallel_size, + ) + # `config.model_parallel.sequence_parallel` is a bool that indicates whether to use sequence parallelism. + # It is not part of the original `parallel_state` API, so we need to set it manually. + parallel_state.sequence_parallel = config.model_parallel.sequence_parallel + if parallel_state.sequence_parallel: + os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "1" + + # Create the local job directory, save the config file, and pipe to a local log. + if distributed.is_rank0(): + os.makedirs(config.job.path_local, exist_ok=True) + # Save the config as .pkl for reproducibility. + LazyConfig.save_pkl(config, f"{config.job.path_local}/config.pkl") + # Save the config as .yaml for reading or parsing experiment hyperparameters. + LazyConfig.save_yaml(config, f"{config.job.path_local}/config.yaml") + dist.barrier() + log.init_loguru_file(f"{config.job.path_local}/stdout.log") + if distributed.is_rank0(): + # Print important environment variables and the effective config. + log.info("Config:\n" + config.pretty_print(use_color=True)) + misc.print_environ_variables(["TORCH_HOME", "OUTPUT_ROOT"]) + # Set the random seed. If multi-GPU, different ranks are set with different seeds. + misc.set_random_seed(seed=config.trainer.seed, by_rank=True) + # Initialize cuDNN. + torch.backends.cudnn.deterministic = config.trainer.cudnn.deterministic + torch.backends.cudnn.benchmark = config.trainer.cudnn.benchmark + # Floating-point precision settings. + torch.backends.cudnn.allow_tf32 = torch.backends.cuda.matmul.allow_tf32 = True + # Initialize the callback functions. + self.callbacks = callback.CallBackGroup(config=config, trainer=self) + # Initialize the model checkpointer. + if config.checkpoint.type is None: + self.checkpointer = Checkpointer(config.checkpoint, config.job, callbacks=self.callbacks) + else: + self.checkpointer: Checkpointer = instantiate( + config.checkpoint.type, config.checkpoint, config.job, callbacks=self.callbacks + ) + # Initialize the timer for speed benchmarking. + self.training_timer = misc.TrainingTimer() + # Send a TimeoutError if a training step takes over timeout_period seconds. + signal.signal(signal.SIGALRM, functools.partial(misc.timeout_handler, config.trainer.timeout_period)) # type: ignore + + def train( + self, + model: Model, + dataloader_train: torch.utils.data.DataLoader, + dataloader_val: torch.utils.data.DataLoader, + ) -> None: + """The training function. + + Args: + model (Model): The PyTorch model. + dataloader_train (torch.utils.data.DataLoader): The training data loader. + dataloader_val (torch.utils.data.DataLoader): The validation data loader. + """ + # Leaving this for backward compability for now, but we can think about moving this to model.on_train_start for all models. + model = model.to("cuda", memory_format=self.config.trainer.memory_format) # type: ignore + model.on_train_start(self.config.trainer.memory_format) + + # Initialize the optimizer, scheduler, and grad_scaler. + self.callbacks.on_optimizer_init_start() + optimizer, scheduler = model.init_optimizer_scheduler(self.config.optimizer, self.config.scheduler) + grad_scaler = torch.amp.GradScaler("cuda", **self.config.trainer.grad_scaler_args) + self.callbacks.on_optimizer_init_end() + # Load the model checkpoint and get the starting iteration number. + iteration = self.checkpointer.load(model, optimizer, scheduler, grad_scaler) + grad_accum_iter = 0 + log.critical(f"Distributed parallelism mode: {self.config.trainer.distributed_parallelism}") + if self.config.trainer.distributed_parallelism == "ddp": + # Create a DDP model wrapper. + model_ddp = distributed.parallel_model_wrapper(self.config.trainer.ddp, model) + elif self.config.trainer.distributed_parallelism == "fsdp": + model_ddp = model + else: + raise ValueError(f"Unknown distributed parallelism mode: {self.config.trainer.distributed_parallelism}") + log.info("Starting training...") + self.callbacks.on_train_start(model, iteration=iteration) + # Initial validation. + if self.config.trainer.run_validation and iteration == 0: + self.validate(model, dataloader_val, iteration=iteration) + _end_training = False + while True: + dataloader_train_iter = iter(dataloader_train) + while True: + self.callbacks.on_before_dataloading(iteration) + with self.training_timer("dataloader_train"): + try: + data_batch = next(dataloader_train_iter) + except StopIteration: + break + self.callbacks.on_after_dataloading(iteration) + # If max_iter is reached, exit the training loop. + if iteration >= self.config.trainer.max_iter: + _end_training = True + break + # Move all tensors in the data batch to GPU device. + data_batch = misc.to(data_batch, device="cuda") + # The actual training step. + self.callbacks.on_training_step_start(model, data_batch, iteration=iteration) + if not model.training: + model_ddp.train() + assert model_ddp.training, "model_ddp is not in training mode." + assert model.training, "model is not in training mode." + output_batch, loss, grad_accum_iter = self.training_step( + model_ddp, + optimizer, + scheduler, + grad_scaler, + data_batch, + iteration=iteration, + grad_accum_iter=grad_accum_iter, + ) + # Do the following when an actual optimizer (update) step has been made. + iteration += 1 + # Save checkpoint. + if iteration % self.config.checkpoint.save_iter == 0: + self.checkpointer.save(model, optimizer, scheduler, grad_scaler, iteration=iteration) + self.callbacks.on_training_step_end(model, data_batch, output_batch, loss, iteration=iteration) + # Validation. + if self.config.trainer.run_validation and iteration % self.config.trainer.validation_iter == 0: + self.validate(model, dataloader_val, iteration=iteration) + # This iteration is successful; reset the timeout signal. + signal.alarm(self.config.trainer.timeout_period) + if _end_training: + break + log.success("Done with training.") + if iteration % self.config.checkpoint.save_iter != 0: + self.checkpointer.save(model, optimizer, scheduler, grad_scaler, iteration=iteration) + self.callbacks.on_train_end(model, iteration=iteration) + self.checkpointer.finalize() + distributed.barrier() + self.callbacks.on_app_end() + + def training_step( + self, + model_ddp: torch.nn.Module | distributed.DistributedDataParallel, + optimizer: torch.optim.Optimizer, + scheduler: torch.optim.lr_scheduler.LRScheduler, + grad_scaler: torch.amp.GradScaler, + data: dict[str, torch.Tensor], + iteration: int = 0, + grad_accum_iter: int = 0, + ) -> tuple[dict[str, torch.Tensor], torch.Tensor, int]: + """The training step. + + Args: + model_ddp (torch.nn.Module | distributed.DistributedDataParallel): The model with a DDP wrapper or, the bare + module, depending on whether distributed training is enabled or not. + optimizer (torch.optim.Optimizer): The model optimizer. + scheduler (torch.optim.lr_scheduler.LRScheduler): The optimization scheduler. + grad_scaler (torch.amp.GradScaler): The gradient scaler (for mixed precision training). + data (dict[str, torch.Tensor]): Data batch (dictionary of tensors). + iteration (int): Current iteration number. + grad_accum_iter (int): Number of gradient accumulation iterations. + + Returns: + output (dict[str, torch.Tensor]): The model output from the training data batch (dictionary of tensors). + loss (torch.Tensor): The total loss of the training data batch. + """ + # Only let DDP sync gradient at the last iteration of the gradient accumulation window + with distributed.ddp_sync_grad(model_ddp, grad_accum_iter == self.config.trainer.grad_accum_iter - 1): + with self.training_timer("forward"): + output_batch, loss = model_ddp.training_step(data, iteration) + self.callbacks.on_before_backward(model_ddp, loss, iteration=iteration) + with self.training_timer("backward"): + loss_scaled = grad_scaler.scale(loss / self.config.trainer.grad_accum_iter) + loss_scaled.backward() + if self.config.trainer.distributed_parallelism == "ddp": + model_ddp.module.on_after_backward() + else: + model_ddp.on_after_backward() + self.callbacks.on_after_backward(model_ddp, iteration=iteration) + grad_accum_iter += 1 + if grad_accum_iter == self.config.trainer.grad_accum_iter: + with self.training_timer("optimizer_step"): + self.callbacks.on_before_optimizer_step( + model_ddp, optimizer, scheduler, grad_scaler, iteration=iteration + ) + grad_scaler.step(optimizer) + grad_scaler.update() + scheduler.step() + self.callbacks.on_before_zero_grad(model_ddp, optimizer, scheduler, iteration=iteration) + if self.config.trainer.distributed_parallelism == "ddp": + model_ddp.module.on_before_zero_grad(optimizer, scheduler, iteration=iteration) + else: + model_ddp.on_before_zero_grad(optimizer, scheduler, iteration=iteration) + optimizer.zero_grad(set_to_none=True) + grad_accum_iter = 0 + return output_batch, loss, grad_accum_iter + + @torch.no_grad() + def validate(self, model: Model, dataloader_val: torch.utils.data.DataLoader, iteration: int = 0) -> None: + """Validate on the full validation dataset. + + Args: + model (Model): The PyTorch model. + dataloader_val (torch.utils.data.DataLoader): The validation data loader. + iteration (int): Current iteration number. + """ + self.callbacks.on_validation_start(model, dataloader_val, iteration=iteration) + model.eval() + # Evaluate on the full validation set. + with ema.ema_scope(model, enabled=model.config.ema.enabled): + for val_iter, data_batch in enumerate(dataloader_val): + if self.config.trainer.max_val_iter is not None and val_iter >= self.config.trainer.max_val_iter: + break + data_batch = misc.to(data_batch, device="cuda") + self.callbacks.on_validation_step_start(model, data_batch, iteration=iteration) + output_batch, loss = model.validation_step(data_batch, iteration) + self.callbacks.on_validation_step_end(model, data_batch, output_batch, loss, iteration=iteration) + self.callbacks.on_validation_end(model, iteration=iteration) diff --git a/examples/post-training_cosmos_transfer_7b_edge.md b/examples/post-training_cosmos_transfer_7b_edge.md new file mode 100644 index 00000000..91564e63 --- /dev/null +++ b/examples/post-training_cosmos_transfer_7b_edge.md @@ -0,0 +1,211 @@ +## Post-training diffusion-based EdgeControl models + +### Model Support Matrix + +We support the following Cosmos Diffusion models for post-training. Review the available models and their compute requirements for post-tuning and inference to determine the best model for your use case. + +| Model Name | Model Status | Compute Requirements for Post-Training | +|----------------------------------------------|------------------|------------------------------------------| +| Cosmos-Transfer1-7B | **Supported** | 8 NVIDIA GPUs* | + +**\*** `H100-80GB` or `A100-80GB` GPUs are recommended. + +### Environment setup + +Please refer to the Post-training section of [INSTALL.md](/INSTALL.md#post-training) for instructions on environment setup. + +### Download Checkpoints + +1. Generate a [Hugging Face](https://huggingface.co/settings/tokens) access token. Set the access token to 'Read' permission (default is 'Fine-grained'). + +2. Log in to Hugging Face with the access token: + +```bash +huggingface-cli login +``` + +3. Accept the [LlamaGuard-7b terms](https://huggingface.co/meta-llama/LlamaGuard-7b) + +4. Download the Cosmos model weights from [Hugging Face](https://huggingface.co/collections/nvidia/cosmos-transfer1-67c9d328196453be6e568d3e): + +```bash +CUDA_HOME=$CONDA_PREFIX PYTHONPATH=$(pwd) python scripts/download_checkpoints.py --output_dir checkpoints/ +``` + +Note that this will require about 300GB of free storage. Not all these checkpoints will be used in every generation. + +5. The downloaded files should be in the following structure: + +``` +checkpoints/ +├── nvidia +│ ├── Cosmos-Transfer1-7B +│ │ ├── base_model.pt +│ │ ├── vis_control.pt +│ │ ├── edge_control.pt +│ │ ├── seg_control.pt +│ │ ├── depth_control.pt +│ │ ├── keypoint_control.pt +│ │ ├── 4kupscaler_control.pt +│ │ ├── config.json +│ │ └── guardrail +│ │ ├── aegis/ +│ │ ├── blocklist/ +│ │ ├── face_blur_filter/ +│ │ └── video_content_safety_filter/ +│ │ +│ ├── Cosmos-Transfer1-7B-Sample-AV/ +│ │ ├── base_model.pt +│ │ ├── hdmap_control.pt +│ │ └── lidar_control.pt +│ │ +│ │── Cosmos-Tokenize1-CV8x8x8-720p +│ │ ├── decoder.jit +│ │ ├── encoder.jit +│ │ ├── autoencoder.jit +│ │ └── mean_std.pt +│ │ +│ └── Cosmos-UpsamplePrompt1-12B-Transfer +│ ├── depth +│ │ ├── consolidated.safetensors +│ │ ├── params.json +│ │ └── tekken.json +│ ├── README.md +│ ├── segmentation +│ │ ├── consolidated.safetensors +│ │ ├── params.json +│ │ └── tekken.json +│ ├── seg_upsampler_example.png +│ └── viscontrol +│ ├── consolidated.safetensors +│ ├── params.json +│ └── tekken.json +│ +├── depth-anything/... +├── facebook/... +├── google-t5/... +└── IDEA-Research/ +``` + +### Examples + +Post-training a Cosmos-Transfer1 model enables you to train the model to generate videos that are more specific to your use case. + +There are 3 steps to post-training: downloading a dataset, preprocessing the data, and post-training the model. + +#### 1. Download a Dataset + +The first step is to download a dataset with videos and captions. + +You must provide a folder containing a collection of videos in **MP4 format**, preferably 720p. These videos should focus on the subject throughout the entire video so that each video chunk contains the subject. + +For example, you can use a subset of [HD-VILA-100M](https://github.com/microsoft/XPretrain/tree/main/hd-vila-100m) dataset for post-training. + +```bash +# Download metadata with video urls and captions +mkdir -p datasets/hdvila +cd datasets/hdvila +wget https://huggingface.co/datasets/TempoFunk/hdvila-100M/resolve/main/hdvila-100M.jsonl +``` + +Run the following command to download the sample videos used for post-training: + +```bash +# Requirements for Youtube video downloads & video clipping +pip install pytubefix ffmpeg +``` + +```bash +# The script will downlaod the original HD-VILA-100M videos, save the corresponding clips, the captions and the metadata. +CUDA_HOME=$CONDA_PREFIX PYTHONPATH=$(pwd) python scripts/download_diffusion_example_data.py --dataset_path datasets/hdvila --N_videos 128 --do_download --do_clip +``` + +#### 2. Preprocessing the Data + +Run the following command to pre-compute T5-XXL embeddings for the video captions used for post-training: + +```bash +# The script will read the captions, save the T5-XXL embeddings in pickle format. +CUDA_HOME=$CONDA_PREFIX PYTHONPATH=$(pwd) python scripts/get_t5_embeddings.py --dataset_path datasets/hdvila +``` + +Dataset folder format: +``` +datasets/hdvila/ +├── metas/ +│ ├── *.json +│ ├── *.txt +├── videos/ +│ ├── *.mp4 +├── t5_xxl/ +│ ├── *.pickle +``` + +Training a VisControl or EdgeControl model is self-supervised: we apply blurs and/or compute canny edges of the input videos on-the-fly during training. Therefore, for these two modalities there is no need to prepare the control input videos separately. + +#### 3. Post-train the Model + +Run the following command to execute an example post-training job with the above data. +```bash +export OUTPUT_ROOT=checkpoints # default value +torchrun --nproc_per_node=8 -m cosmos_transfer1.diffusion.training.train --config=cosmos_transfer1/diffusion/training/config/config.py --experiment=CTRL_7Bv1pt3_lvg_tp_121frames_control_input_edge_block3 +``` + +checkpoints/cosmos_transfer1/CTRL_7Bv1_lvg/CTRL_7Bv1pt3_lvg_tp_121frames_control_input_edge_block3/config.yaml +This command will use ``cosmos_transfer1/diffusion/config/training/experiment/ctrl_7b_tp_121frames.py` to register experiments for all `hint_keys` (control modalities). + +Then the model will be post-trained using the above hdvila dataset. +See the function `make_ctrlnet_config_7b_training` defined in `cosmos_transfer1/diffusion/config/training/experiment/ctrl_7b_tp_121frames.py` to understand how the detailed configs of the model, trainer, dataloader etc. are defined. For the data specifically: + +```python +num_frames = 121 +example_video_dataset = L(Dataset)( + dataset_dir="datasets/hdvila", + sequence_interval=1, + num_frames=num_frames, + video_size=(720, 1280), + start_frame_interval=1, +) + +dataloader_train = L(DataLoader)( + dataset=example_video_dataset, + sampler=L(get_sampler)(dataset=example_video_dataset), + batch_size=1, + drop_last=True, +) +... + +config = LazyDict( + dict( + ... + dataloader_train=dataloader_train, + ... + ) +) +... +``` + +The checkpoints will be saved to `${OUTPUT_ROOT}/PROJECT/GROUP/NAME`. +In the above example, `PROJECT` is `cosmos_transfer1_posttrain`, `GROUP` is `CTRL_7Bv1_lvg`, `NAME` is `CTRL_7Bv1pt3_lvg_tp_121frames_control_input_edge_block3`. + +See the job config to understand how they are determined. +```python +edgecontrol_7b_example_hdvila = LazyDict( + dict( + ... + job=dict( + project="cosmos_transfer1_posttrain", + group="CTRL_7Bv1_lvg", + name="CTRL_7Bv1pt3_lvg_tp_121frames_control_input_edge_block3", + ), + ... + ) +) +``` + +During the training, the checkpoints will be saved in the below structure. +``` +checkpoints/cosmos_transfer1_posttrain/CTRL_7Bv1_lvg/CTRL_7Bv1pt3_lvg_tp_121frames_control_input_edge_block3/checkpoints/ +├── iter_{NUMBER}_reg_model.pt +├── iter_{NUMBER}_ema_model.pt +``` \ No newline at end of file diff --git a/scripts/get_t5_embeddings.py b/scripts/get_t5_embeddings.py new file mode 100644 index 00000000..53b6ebe8 --- /dev/null +++ b/scripts/get_t5_embeddings.py @@ -0,0 +1,126 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import os +import pickle +from typing import Tuple + +import numpy as np +import torch +from transformers import T5EncoderModel, T5TokenizerFast + +"""example command +CUDA_HOME=$CONDA_PREFIX PYTHONPATH=$(pwd) python scripts/get_t5_embeddings.py --dataset_path datasets/hdvila +""" + + +def parse_args() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser(description="Compute T5 embeddings for text prompts") + parser.add_argument("--dataset_path", type=str, default="datasets/hdvila", help="Root path to the dataset") + parser.add_argument("--max_length", type=int, default=512, help="Maximum length of the text embedding") + parser.add_argument( + "--pretrained_model_name_or_path", type=str, default="google-t5/t5-11b", help="T5 model name or the local path" + ) + parser.add_argument("--cache_dir", type=str, default="checkpoints", help="Directory to cache the T5 model") + return parser.parse_args() + + +def init_t5( + pretrained_model_name_or_path: str = "google-t5/t5-11b", max_length: int = 512, cache_dir: str = "~/.cache" +) -> Tuple[T5TokenizerFast, T5EncoderModel]: + """Initialize and return the T5 tokenizer and text encoder.""" + tokenizer = T5TokenizerFast.from_pretrained( + pretrained_model_name_or_path, model_max_length=max_length, cache_dir=cache_dir + ) + text_encoder = T5EncoderModel.from_pretrained(pretrained_model_name_or_path, cache_dir=cache_dir) + text_encoder.to("cuda") + text_encoder.eval() + return tokenizer, text_encoder + + +@torch.inference_mode() +def encode_for_batch(tokenizer, encoder, prompts: list[str], max_length=512) -> list: + """ + Encode a batch of text prompts to a batch of T5 embeddings. + Parameters: + tokenizer: T5 embedding tokenizer. + encoder: T5 embedding text encoder. + prompts: A batch of text prompts. + max_length: Sequence length of text embedding (defaults to 512). + """ + + batch_encoding = tokenizer.batch_encode_plus( + prompts, + return_tensors="pt", + truncation=True, + padding="max_length", + max_length=max_length, + return_length=True, + return_offsets_mapping=False, + ) + + # We expect all the processing is done on GPU. + input_ids = batch_encoding.input_ids.cuda() + attn_mask = batch_encoding.attention_mask.cuda() + + outputs = encoder(input_ids=input_ids, attention_mask=attn_mask) + encoded_text = outputs.last_hidden_state + + lengths = attn_mask.sum(dim=1).cpu() + for batch_id in range(encoded_text.shape[0]): + encoded_text[batch_id][lengths[batch_id] :] = 0 + + encoded_text = encoded_text.cpu().numpy().astype(np.float16) + encoded_text = encoded_text[:, :max_length] + + # trim zeros to save space + encoded_text = [encoded_text[batch_id][: lengths[batch_id]] for batch_id in range(encoded_text.shape[0])] + + return encoded_text + + +def main(args) -> None: + metas_dir = os.path.join(args.dataset_path, "metas") + metas_list = [ + os.path.join(metas_dir, filename) for filename in sorted(os.listdir(metas_dir)) if filename.endswith(".txt") + ] + + t5_xxl_dir = os.path.join(args.dataset_path, "t5_xxl") + os.makedirs(t5_xxl_dir, exist_ok=True) + + # Initialize T5 + tokenizer, text_encoder = init_t5(cache_dir=args.cache_dir) + + for meta_filename in metas_list: + t5_xxl_filename = os.path.join(t5_xxl_dir, os.path.basename(meta_filename).replace(".txt", ".pickle")) + if os.path.exists(t5_xxl_filename): + # Skip if the file already exists + continue + + with open(meta_filename, "r") as fp: + prompt = fp.read().strip() + + # Compute T5 embeddings + encoded_text = encode_for_batch(tokenizer, text_encoder, [prompt]) + + # Save T5 embeddings as pickle file + with open(t5_xxl_filename, "wb") as fp: + pickle.dump(encoded_text, fp) + + +if __name__ == "__main__": + args = parse_args() + main(args) diff --git a/scripts/test_environment.py b/scripts/test_environment.py index 74fd6452..d3f99277 100644 --- a/scripts/test_environment.py +++ b/scripts/test_environment.py @@ -13,10 +13,20 @@ # See the License for the specific language governing permissions and # limitations under the License. +import argparse import importlib import os import sys +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--training", + action="store_true", + help="Whether to check training-specific dependencies", + ) + return parser.parse_args() + def check_packages(package_list): global all_success @@ -29,6 +39,8 @@ def check_packages(package_list): else: print(f"\033[92m[SUCCESS]\033[0m {package} found") +args = parse_args() + if not (sys.version_info.major == 3 and sys.version_info.minor >= 10): detected = f"{sys.version_info.major}.{sys.version_info.minor}.{sys.version_info.micro}" print(f"\033[91m[ERROR]\033[0m Python 3.10+ is required. You have: \033[93m{detected}\033[0m") @@ -47,9 +59,14 @@ def check_packages(package_list): "transformer_engine", "vllm", ] +packages_training = [ + "apex.multi_tensor_apply", +] all_success = True check_packages(packages) +if args.training: + check_packages(packages_training) if all_success: print("-----------------------------------------------------------") From e41cc0c611f6d6f10f2780f8015fafd758ac650f Mon Sep 17 00:00:00 2001 From: Qianli Ma Date: Thu, 10 Apr 2025 18:05:23 -0700 Subject: [PATCH 02/10] feat: add separate model definitions supporting tp/sp for training; update configs --- cosmos_transfer1/checkpointer/fast_tp.py | 113 +++ cosmos_transfer1/diffusion/conditioner.py | 5 +- .../diffusion/config/config_train.py | 2 +- .../diffusion/config/training/checkpoint.py | 3 +- .../experiment/ctrl_7b_tp_121frames.py | 4 +- .../diffusion/config/training/registry.py | 11 +- .../diffusion/model/model_ctrl.py | 2 +- .../networks/general_dit_video_conditioned.py | 3 +- .../diffusion/training/models/extend_model.py | 576 +++++++++++ .../training/models/extend_model_multiview.py | 448 +++++++++ .../diffusion/training/models/model.py | 660 +++++++++++++ .../diffusion/training/models/model_ctrl.py | 720 ++++++++++++++ .../diffusion/training/models/model_image.py | 930 ++++++++++++++++++ .../training/models/model_multiview.py | 224 +++++ .../diffusion/training/utils/fsdp_helper.py | 159 +++ 15 files changed, 3846 insertions(+), 14 deletions(-) create mode 100644 cosmos_transfer1/checkpointer/fast_tp.py create mode 100644 cosmos_transfer1/diffusion/training/models/extend_model.py create mode 100644 cosmos_transfer1/diffusion/training/models/extend_model_multiview.py create mode 100644 cosmos_transfer1/diffusion/training/models/model.py create mode 100644 cosmos_transfer1/diffusion/training/models/model_ctrl.py create mode 100644 cosmos_transfer1/diffusion/training/models/model_image.py create mode 100644 cosmos_transfer1/diffusion/training/models/model_multiview.py create mode 100644 cosmos_transfer1/diffusion/training/utils/fsdp_helper.py diff --git a/cosmos_transfer1/checkpointer/fast_tp.py b/cosmos_transfer1/checkpointer/fast_tp.py new file mode 100644 index 00000000..16732924 --- /dev/null +++ b/cosmos_transfer1/checkpointer/fast_tp.py @@ -0,0 +1,113 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from typing import Any, Set + +import torch + +from cosmos_transfer1.checkpointer.ddp_checkpointer import StateDictItemPath +from cosmos_transfer1.checkpointer.tp_checkpointer import Checkpointer as TPCheckpointer +from cosmos_transfer1.utils import distributed, log, misc +from cosmos_transfer1.utils.easy_io import easy_io +from cosmos_transfer1.diffusion.training.models.model import DiffusionModel + + + +class Checkpointer(TPCheckpointer): + def load_broadcast_state_dict( + self, checkpoint_path: str, model: DiffusionModel, resume_keys: Set + ) -> dict[str, Any]: + """ + Load state_dict and broadcast efficiently. + + This method optimizes checkpoint loading for distributed training for improved + connection speed and reliability. + + The main steps are: + 1. Retrieve TP-rank-specific checkpoints for each GPU of DDP-rank 0 + and CP-rank 0. + 2. Each rank loads its corresponding checkpoint either from a local cache or + receives it via broadcast. + + This approach ensures that each MP (Model Parallelism) rank loads its specific + part of the model, which is crucial for scenarios where different parts of the + model are distributed across multiple GPUs. + + The method supports both Tensor Parallelism (TP) and standard Data Parallel (DP) + training. For TP, each rank can efficiently load its specific checkpoint from S3. + For standard DDP without TP, the default broadcast mechanism is used. + + Args: + checkpoint_path (str): The base path of the checkpoint in S3. + model (DiffusionModel): The model being loaded. + resume_keys (Set): Set of keys to resume from the checkpoint. + + Returns: + dict[str, Any]: A dictionary containing the loaded state for each resumed key. + + Note: + This implementation has been tested and optimized for 4K GPU training jobs, + showing significant improvements in connection speed and overall efficiency. + """ + state_dict = {} + sorted_resume_keys = sorted(resume_keys) + for key in sorted_resume_keys: + _ckpt_path = self.add_type_postfix_to_checkpoint_path(key, checkpoint_path, model) + _state_dict = easy_io.load(_ckpt_path, fast_backend=True, backend_key=self.load_s3_backend_key) + state_dict[key] = _state_dict + self.print(f"Loaded checkpoint from: {_ckpt_path}") + distributed.barrier() + return state_dict + + @misc.timer("checkpoint saving") + def _save_worker(self, state_dict: dict[str, StateDictItemPath], checkpoint_file: str, rank: int = 0) -> None: + """ + similar to the original _save_worker, but with the following changes: + * fast_backend=False to avoid high CPU usage + """ + try: + for key, item in state_dict.items(): + self.print(f"Saving {key} to {item.save_path}") + try: + easy_io.dump( + item.state_dict, + item.save_path, + fast_backend=False, # too cpu heavy + backend_key=self.save_s3_backend_key, + ) + self.print(f"Saved {key} to {item.save_path}") + except Exception as e: + self.print(f"Failed to save {key} to {item.save_path}: {str(e)}") + raise # Re-raise the exception after logging + + # Synchronize only rank 0 of each model parallel group + if self.mp_world_size > 1: + torch.distributed.barrier(group=self.mp_gloo_pg) + + # Only rank 0 of MP group and rank 0 of DP with CP updates latest_checkpoint.txt + if self.mp_rank == 0 and self.rank_dp_w_cp == 0: + self._write_latest_checkpoint_file(checkpoint_file) + + if distributed.get_rank() == 0: # only rank 0 saves trained_data_record + if "trained_data_record" in state_dict["model"].state_dict: + self._write_trained_data_record( + checkpoint_file, state_dict["model"].state_dict["trained_data_record"] + ) + + iteration = int(checkpoint_file.replace("iter_", "").replace(".pt", "")) + self.callbacks.on_save_checkpoint_success(iteration=iteration) + except Exception as e: # noqa: BLE001 + log.exception(f"Checkpoint failed to upload: {e}", rank0_only=not self.verbose) diff --git a/cosmos_transfer1/diffusion/conditioner.py b/cosmos_transfer1/diffusion/conditioner.py index 00bfe588..82016cce 100644 --- a/cosmos_transfer1/diffusion/conditioner.py +++ b/cosmos_transfer1/diffusion/conditioner.py @@ -116,6 +116,7 @@ class BaseVideoCondition: num_frames: Optional[torch.Tensor] = None image_size: Optional[torch.Tensor] = None scalar_feature: Optional[torch.Tensor] = None + frame_repeat: Optional[torch.Tensor] = None def to_dict(self) -> Dict[str, Optional[torch.Tensor]]: return {f.name: getattr(self, f.name) for f in fields(self)} @@ -351,4 +352,6 @@ def forward( output["hint_key"] = batch["hint_key"] if "control_weight" in batch: output["control_weight"] = batch["control_weight"] - return BaseWithCtrlCondition(**output) + if "num_layers_to_use" in batch: + output["num_layers_to_use"] = batch["num_layers_to_use"] + return BaseWithCtrlCondition(**output) \ No newline at end of file diff --git a/cosmos_transfer1/diffusion/config/config_train.py b/cosmos_transfer1/diffusion/config/config_train.py index 3d2defc8..35bfffe4 100644 --- a/cosmos_transfer1/diffusion/config/config_train.py +++ b/cosmos_transfer1/diffusion/config/config_train.py @@ -20,7 +20,7 @@ from cosmos_transfer1.diffusion.config.transfer.model import CtrlModelConfig from cosmos_transfer1.checkpointer.ema_fsdp_checkpointer import CheckpointConfig from cosmos_transfer1.diffusion.config.training.registry_extra import register_configs -from cosmos_transfer1.diffusion.model.model_ctrl import VideoDiffusionModelWithCtrl +from cosmos_transfer1.diffusion.training.models.model_ctrl import VideoDiffusionModelWithCtrl from cosmos_transfer1.utils import config from cosmos_transfer1.utils.config_helper import import_all_modules_from_package from cosmos_transfer1.utils.lazy_config import PLACEHOLDER diff --git a/cosmos_transfer1/diffusion/config/training/checkpoint.py b/cosmos_transfer1/diffusion/config/training/checkpoint.py index 7248fe53..65aff093 100644 --- a/cosmos_transfer1/diffusion/config/training/checkpoint.py +++ b/cosmos_transfer1/diffusion/config/training/checkpoint.py @@ -20,8 +20,9 @@ from cosmos_transfer1.checkpointer.fsdp_checkpointer import FSDPCheckpointer from cosmos_transfer1.checkpointer.multi_rank_checkpointer import MultiRankCheckpointer from cosmos_transfer1.checkpointer.tp_checkpointer import Checkpointer as TPCheckpointer - +from cosmos_transfer1.checkpointer.fast_tp import Checkpointer as FastTPCheckpointer MULTI_RANK_CHECKPOINTER: Dict[str, str] = L(MultiRankCheckpointer)() FSDP_CHECKPOINTER: Dict[str, str] = L(FSDPCheckpointer)() MODEL_PARALLEL_CHECKPOINTER: Dict[str, str] = L(TPCheckpointer)() +FAST_TP_CHECKPOINTER: Dict[str, str] = L(FastTPCheckpointer)() \ No newline at end of file diff --git a/cosmos_transfer1/diffusion/config/training/experiment/ctrl_7b_tp_121frames.py b/cosmos_transfer1/diffusion/config/training/experiment/ctrl_7b_tp_121frames.py index ec4f4854..f161f848 100644 --- a/cosmos_transfer1/diffusion/config/training/experiment/ctrl_7b_tp_121frames.py +++ b/cosmos_transfer1/diffusion/config/training/experiment/ctrl_7b_tp_121frames.py @@ -29,7 +29,7 @@ from cosmos_transfer1.utils.lazy_config import LazyDict from cosmos_transfer1.diffusion.config.transfer.blurs import random_blur_config from cosmos_transfer1.diffusion.config.transfer.conditioner import CTRL_HINT_KEYS_COMB -from cosmos_transfer1.diffusion.model.model_ctrl import VideoDiffusionModelWithCtrl +from cosmos_transfer1.diffusion.training.models.model_ctrl import VideoDiffusionModelWithCtrl # this one has training support from cosmos_transfer1.diffusion.networks.general_dit_video_conditioned import VideoExtendGeneralDIT cs = ConfigStore.instance() @@ -65,7 +65,7 @@ def make_ctrlnet_config_7b_training( {"override /hint_key": hint_key}, {"override /callbacks": "basic"}, {"override /checkpoint": "local"}, - {"override /ckpt_klass": "multi_rank"}, + {"override /ckpt_klass": "fast_tp"}, # {"override /data_train": data_train}, {"override /data_val": data_val}, diff --git a/cosmos_transfer1/diffusion/config/training/registry.py b/cosmos_transfer1/diffusion/config/training/registry.py index 9d5be5c0..c33e1d56 100644 --- a/cosmos_transfer1/diffusion/config/training/registry.py +++ b/cosmos_transfer1/diffusion/config/training/registry.py @@ -28,6 +28,7 @@ FSDP_CHECKPOINTER, MULTI_RANK_CHECKPOINTER, MODEL_PARALLEL_CHECKPOINTER, + FAST_TP_CHECKPOINTER, ) @@ -59,13 +60,9 @@ def register_checkpoint_credential(cs): def register_checkpointer(cs): cs.store(group="ckpt_klass", package="checkpoint.type", name="fsdp", node=FSDP_CHECKPOINTER) cs.store(group="ckpt_klass", package="checkpoint.type", name="multi_rank", node=MULTI_RANK_CHECKPOINTER) - cs.store( - group="ckpt_klass", - package="checkpoint.type", - name="tp", - node=MODEL_PARALLEL_CHECKPOINTER, - ) - + cs.store(group="ckpt_klass", package="checkpoint.type", name="tp", node=MODEL_PARALLEL_CHECKPOINTER) + cs.store(group="ckpt_klass", package="checkpoint.type", name="fast_tp", node=FAST_TP_CHECKPOINTER) + def register_configs(): cs = ConfigStore.instance() diff --git a/cosmos_transfer1/diffusion/model/model_ctrl.py b/cosmos_transfer1/diffusion/model/model_ctrl.py index c8f840cc..056c9898 100644 --- a/cosmos_transfer1/diffusion/model/model_ctrl.py +++ b/cosmos_transfer1/diffusion/model/model_ctrl.py @@ -495,7 +495,7 @@ def x0_fn(noise_x: torch.Tensor, sigma: torch.Tensor): if getattr(uncondition, hint_key) is not None: setattr(uncondition, hint_key, latent_hint[idx : idx + 1]) - if self.is_image_batch(data_batch) or not issubclass(base_class, ExtendVideoDiffusionModel): + if self.is_image_batch(data_batch): cond_x0 = self.denoise(noise_x, sigma, condition).x0 uncond_x0 = self.denoise(noise_x, sigma, uncondition).x0 else: diff --git a/cosmos_transfer1/diffusion/networks/general_dit_video_conditioned.py b/cosmos_transfer1/diffusion/networks/general_dit_video_conditioned.py index 9cb2fe90..c22aad4c 100644 --- a/cosmos_transfer1/diffusion/networks/general_dit_video_conditioned.py +++ b/cosmos_transfer1/diffusion/networks/general_dit_video_conditioned.py @@ -191,7 +191,8 @@ def forward_before_blocks( if self.add_augment_sigma_embedding: if condition_video_augment_sigma is None: # Handling image case - # Note: for video case, when there is not condition frames, we also set it as zero, see DiffusionV2WModel augment_conditional_latent_frames function + # Note: for video case, when there is not condition frames, we also set it as zero, see + # the augment_conditional_latent_frames function in DiffusionV2WModel and ExtendDiffusionModel. assert data_type == DataType.IMAGE, "condition_video_augment_sigma is required for video data type" condition_video_augment_sigma = torch.zeros_like(timesteps.flatten()) diff --git a/cosmos_transfer1/diffusion/training/models/extend_model.py b/cosmos_transfer1/diffusion/training/models/extend_model.py new file mode 100644 index 00000000..a637f5cb --- /dev/null +++ b/cosmos_transfer1/diffusion/training/models/extend_model.py @@ -0,0 +1,576 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass +from statistics import NormalDist +from typing import Callable, Dict, Optional, Tuple, Union + +import numpy as np +import torch +from einops import rearrange +from megatron.core import parallel_state +from torch import Tensor + +from cosmos_transfer1.diffusion.config.base.conditioner import VideoCondBoolConfig +from cosmos_transfer1.diffusion.functional.batch_ops import batch_mul +from cosmos_transfer1.diffusion.conditioner import DataType, VideoExtendCondition +from cosmos_transfer1.diffusion.module.parallel import cat_outputs_cp, split_inputs_cp +from cosmos_transfer1.diffusion.training.models.model import DiffusionModel as BaseModel +from cosmos_transfer1.diffusion.training.models.model import _broadcast, broadcast_condition +from cosmos_transfer1.diffusion.training.models.model_image import diffusion_fsdp_class_decorator +from cosmos_transfer1.utils import log, misc + + +@dataclass +class VideoDenoisePrediction: + x0: torch.Tensor # clean data prediction + eps: Optional[torch.Tensor] = None # noise prediction + logvar: Optional[torch.Tensor] = None # log variance of noise prediction, can be used a confidence / uncertainty + net_in: Optional[torch.Tensor] = None # input to the network + net_x0_pred: Optional[torch.Tensor] = None # prediction of x0 from the network + xt: Optional[torch.Tensor] = None # input to the network, before muliply with c_in + x0_pred_replaced: Optional[torch.Tensor] = None # x0 prediction with condition region replaced by gt_latent + + +def normalize_condition_latent(condition_latent): + """Normalize the condition latent tensor to have zero mean and unit variance + Args: + condition_latent (torch.Tensor): latent tensor in shape B,C,T,H,W + """ + condition_latent_2D = rearrange(condition_latent, "b c t h w -> b c t (h w)") + mean = condition_latent_2D.mean(dim=-1) + std = condition_latent_2D.std(dim=-1) + # bct -> bct11 + mean = mean.unsqueeze(-1).unsqueeze(-1) + std = std.unsqueeze(-1).unsqueeze(-1) + condition_latent = (condition_latent - mean) / std + return condition_latent + + +class ExtendDiffusionModel(BaseModel): + def __init__(self, config): + super().__init__(config) + self.is_extend_model = True + + def get_data_and_condition( + self, data_batch: dict[str, Tensor], num_condition_t: Union[int, None] = None + ) -> Tuple[Tensor, VideoExtendCondition]: + raw_state, latent_state, condition = super().get_data_and_condition(data_batch) + if condition.data_type == DataType.VIDEO: + if self.config.conditioner.video_cond_bool.sample_tokens_start_from_p_or_i: + latent_state = self.sample_tokens_start_from_p_or_i(latent_state) + condition = self.add_condition_video_indicator_and_video_input_mask( + latent_state, condition, num_condition_t=num_condition_t + ) + if self.config.conditioner.video_cond_bool.add_pose_condition: + condition = self.add_condition_pose(data_batch, condition) + log.debug(f"condition.data_type {condition.data_type}") + return raw_state, latent_state, condition + + def draw_augment_sigma_and_epsilon( + self, size: int, condition: VideoExtendCondition, p_mean: float, p_std: float, multiplier: float + ) -> Tensor: + is_video_batch = condition.data_type == DataType.VIDEO + del condition + batch_size = size[0] + epsilon = torch.randn(size, **self.tensor_kwargs) + + gaussian_dist = NormalDist(mu=p_mean, sigma=p_std) + cdf_vals = np.random.uniform(size=(batch_size)) + samples_interval_gaussian = [gaussian_dist.inv_cdf(cdf_val) for cdf_val in cdf_vals] + + log_sigma = torch.tensor(samples_interval_gaussian, device="cuda") + sigma_B = torch.exp(log_sigma).to(**self.tensor_kwargs) + + sigma_B = _broadcast(sigma_B * multiplier, to_tp=True, to_cp=is_video_batch) + epsilon = _broadcast(epsilon, to_tp=True, to_cp=is_video_batch) + return sigma_B, epsilon + + def augment_conditional_latent_frames( + self, + condition: VideoExtendCondition, + cfg_video_cond_bool: VideoCondBoolConfig, + gt_latent: Tensor, + condition_video_augment_sigma_in_inference: float = 0.001, + sigma: Tensor = None, + seed_inference: int = 1, + ) -> Union[VideoExtendCondition, Tensor]: + """This function is used to augment the condition input with noise + Args: + condition (VideoExtendCondition): condition object + condition_video_indicator: binary tensor indicating the region is condition(value=1) or generation(value=0). Bx1xTx1x1 tensor. + condition_video_input_mask: input mask for the network input, indicating the condition region. B,1,T,H,W tensor. will be concat with the input for the network. + cfg_video_cond_bool (VideoCondBoolConfig): video condition bool config + gt_latent (Tensor): ground truth latent tensor in shape B,C,T,H,W + condition_video_augment_sigma_in_inference (float): sigma for condition video augmentation in inference + sigma (Tensor): noise level for the generation region + Returns: + VideoExtendCondition: updated condition object + condition_video_augment_sigma: sigma for the condition region, feed to the network + augment_latent (Tensor): augmented latent tensor in shape B,C,T,H,W + + """ + + if cfg_video_cond_bool.apply_corruption_to_condition_region == "noise_with_sigma": + # Training only, sample sigma for the condition region + augment_sigma, _ = self.draw_augment_sigma_and_epsilon( + gt_latent.shape, + condition, + cfg_video_cond_bool.augment_sigma_sample_p_mean, + cfg_video_cond_bool.augment_sigma_sample_p_std, + cfg_video_cond_bool.augment_sigma_sample_multiplier, + ) + noise = torch.randn(*gt_latent.shape, **self.tensor_kwargs) + + elif cfg_video_cond_bool.apply_corruption_to_condition_region == "noise_with_sigma_fixed": + # Inference only, use fixed sigma for the condition region + log.debug( + f"condition_video_augment_sigma_in_inference={condition_video_augment_sigma_in_inference}, sigma={sigma.flatten()[0]}" + ) + assert ( + condition_video_augment_sigma_in_inference is not None + ), "condition_video_augment_sigma_in_inference should be provided" + augment_sigma = condition_video_augment_sigma_in_inference + + if augment_sigma >= sigma.flatten()[0]: + # This is a inference trick! If the sampling sigma is smaller than the augment sigma, we will start denoising the condition region together. + # This is achieved by setting all region as `generation`, i.e. value=0 + log.debug("augment_sigma larger than sigma or other frame, remove condition") + condition.condition_video_indicator = condition.condition_video_indicator * 0 + + augment_sigma = torch.tensor([augment_sigma], **self.tensor_kwargs) + + # Inference, use fixed seed + noise = misc.arch_invariant_rand( + gt_latent.shape, + torch.float32, + self.tensor_kwargs["device"], + seed_inference, + ) + else: + raise ValueError(f"does not support {cfg_video_cond_bool.apply_corruption_to_condition_region}") + + # Now apply the augment_sigma to the gt_latent + + augment_latent = gt_latent + noise * augment_sigma.view(-1, 1, 1, 1, 1) + _, _, c_in_augment, c_noise_augment = self.scaling(sigma=augment_sigma) + + if cfg_video_cond_bool.condition_on_augment_sigma: # model takes augment_sigma as input + if condition.condition_video_indicator.sum() > 0: # has condition frames + condition.condition_video_augment_sigma = c_noise_augment + else: # no condition frames + condition.condition_video_augment_sigma = torch.zeros_like(c_noise_augment) + + # Multiply the whole latent with c_in_augment + augment_latent_cin = batch_mul(augment_latent, c_in_augment) + + # Since the whole latent will multiply with c_in later, we devide the value to cancel the effect + _, _, c_in, _ = self.scaling(sigma=sigma) + augment_latent_cin = batch_mul(augment_latent_cin, 1 / c_in) + + return condition, augment_latent_cin + + def drop_out_condition_region( + self, augment_latent: Tensor, noise_x: Tensor, cfg_video_cond_bool: VideoCondBoolConfig + ) -> Tensor: + """Use for CFG on input frames, we drop out the conditional region + There are two option: + 1. when we dropout, we set the region to be zero + 2. when we dropout, we set the region to be noise_x + """ + # Unconditional case, use for cfg + if cfg_video_cond_bool.cfg_unconditional_type == "zero_condition_region_condition_mask": + # Set the condition location input to be zero + augment_latent_drop = torch.zeros_like(augment_latent) + elif cfg_video_cond_bool.cfg_unconditional_type == "noise_x_condition_region": + # Set the condition location input to be noise_x, i.e., same as base model training + augment_latent_drop = noise_x + else: + raise NotImplementedError( + f"cfg_unconditional_type {cfg_video_cond_bool.cfg_unconditional_type} not implemented" + ) + return augment_latent_drop + + def denoise( + self, + noise_x: Tensor, + sigma: Tensor, + condition: VideoExtendCondition, + condition_video_augment_sigma_in_inference: float = 0.001, + seed_inference: int = 1, + ) -> VideoDenoisePrediction: + """ + Denoise the noisy input tensor. + + Args: + noise_x (Tensor): Noisy input tensor. + sigma (Tensor): Noise level. + condition (VideoExtendCondition): Condition for denoising. + condition_video_augment_sigma_in_inference (float): sigma for condition video augmentation in inference + + Returns: + Tensor: Denoised output tensor. + """ + if condition.data_type == DataType.IMAGE: + pred = super().denoise(noise_x, sigma, condition) + log.debug(f"hit image denoise, noise_x shape {noise_x.shape}, sigma shape {sigma.shape}", rank0_only=False) + return VideoDenoisePrediction( + x0=pred.x0, + eps=pred.eps, + logvar=pred.logvar, + xt=noise_x, + ) + else: + assert ( + condition.gt_latent is not None + ), f"find None gt_latent in condition, likely didn't call self.add_condition_video_indicator_and_video_input_mask when preparing the condition or this is a image batch but condition.data_type is wrong, get {noise_x.shape}" + gt_latent = condition.gt_latent + cfg_video_cond_bool: VideoCondBoolConfig = self.config.conditioner.video_cond_bool + + condition_latent = gt_latent + + if cfg_video_cond_bool.normalize_condition_latent: + condition_latent = normalize_condition_latent(condition_latent) + + # Augment the latent with different sigma value, and add the augment_sigma to the condition object if needed + condition, augment_latent = self.augment_conditional_latent_frames( + condition, + cfg_video_cond_bool, + condition_latent, + condition_video_augment_sigma_in_inference, + sigma, + seed_inference=seed_inference, + ) + condition_video_indicator = condition.condition_video_indicator # [B, 1, T, 1, 1] + if parallel_state.get_context_parallel_world_size() > 1: + cp_group = parallel_state.get_context_parallel_group() + condition_video_indicator = split_inputs_cp(condition_video_indicator, seq_dim=2, cp_group=cp_group) + augment_latent = split_inputs_cp(augment_latent, seq_dim=2, cp_group=cp_group) + gt_latent = split_inputs_cp(gt_latent, seq_dim=2, cp_group=cp_group) + + if not condition.video_cond_bool: + # Unconditional case, drop out the condition region + augment_latent = self.drop_out_condition_region(augment_latent, noise_x, cfg_video_cond_bool) + + # Compose the model input with condition region (augment_latent) and generation region (noise_x) + new_noise_xt = condition_video_indicator * augment_latent + (1 - condition_video_indicator) * noise_x + # Call the abse model + denoise_pred = super().denoise(new_noise_xt, sigma, condition) + + x0_pred_replaced = condition_video_indicator * gt_latent + (1 - condition_video_indicator) * denoise_pred.x0 + if cfg_video_cond_bool.compute_loss_for_condition_region: + # We also denoise the conditional region + x0_pred = denoise_pred.x0 + else: + x0_pred = x0_pred_replaced + + return VideoDenoisePrediction( + x0=x0_pred, + eps=batch_mul(noise_x - x0_pred, 1.0 / sigma), + logvar=denoise_pred.logvar, + net_in=batch_mul(1.0 / torch.sqrt(self.sigma_data**2 + sigma**2), new_noise_xt), + net_x0_pred=denoise_pred.x0, + xt=new_noise_xt, + x0_pred_replaced=x0_pred_replaced, + ) + + def generate_samples_from_batch( + self, + data_batch: Dict, + guidance: float = 1.5, + seed: int = 1, + state_shape: Tuple | None = None, + n_sample: int | None = None, + is_negative_prompt: bool = False, + num_steps: int = 35, + condition_latent: Union[torch.Tensor, None] = None, + num_condition_t: Union[int, None] = None, + condition_video_augment_sigma_in_inference: float = None, + add_input_frames_guidance: bool = False, + return_noise: bool = False, + ) -> Tensor | Tuple[Tensor, Tensor]: + """ + Generate samples from the batch. Based on given batch, it will automatically determine whether to generate image or video samples. + Different from the base model, this function support condition latent as input, it will create a differnt x0_fn if condition latent is given. + If this feature is stablized, we could consider to move this function to the base model. + + Args: + condition_latent (Optional[torch.Tensor]): latent tensor in shape B,C,T,H,W as condition to generate video. + num_condition_t (Optional[int]): number of condition latent T, if None, will use the whole first half + + add_input_frames_guidance (bool): add guidance to the input frames, used for cfg on input frames + return_noise (bool): return the initial noise or not, used for ODE pairs generation + """ + self._normalize_video_databatch_inplace(data_batch) + self._augment_image_dim_inplace(data_batch) + is_image_batch = self.is_image_batch(data_batch) + if is_image_batch: + log.debug("image batch, call base model generate_samples_from_batch") + return super().generate_samples_from_batch( + data_batch, + guidance=guidance, + seed=seed, + state_shape=state_shape, + n_sample=n_sample, + is_negative_prompt=is_negative_prompt, + num_steps=num_steps, + ) + if n_sample is None: + input_key = self.input_image_key if is_image_batch else self.input_data_key + n_sample = data_batch[input_key].shape[0] + if state_shape is None: + if is_image_batch: + state_shape = (self.state_shape[0], 1, *self.state_shape[2:]) # C,T,H,W + else: + log.debug(f"Default Video state shape is used. {self.state_shape}") + state_shape = self.state_shape + + assert condition_latent is not None, "condition_latent should be provided" + + x0_fn = self.get_x0_fn_from_batch_with_condition_latent( + data_batch, + guidance, + is_negative_prompt=is_negative_prompt, + condition_latent=condition_latent, + num_condition_t=num_condition_t, + condition_video_augment_sigma_in_inference=condition_video_augment_sigma_in_inference, + add_input_frames_guidance=add_input_frames_guidance, + seed_inference=seed, # Use for noise of augment sigma + ) + + x_sigma_max = ( + misc.arch_invariant_rand( + (n_sample,) + tuple(state_shape), torch.float32, self.tensor_kwargs["device"], seed + ) + * self.sde.sigma_max + ) + if self.net.is_context_parallel_enabled: + x_sigma_max = split_inputs_cp(x=x_sigma_max, seq_dim=2, cp_group=self.net.cp_group) + + samples = self.sampler(x0_fn, x_sigma_max, num_steps=num_steps, sigma_max=self.sde.sigma_max) + if self.net.is_context_parallel_enabled: + samples = cat_outputs_cp(samples, seq_dim=2, cp_group=self.net.cp_group) + + if return_noise: + if self.net.is_context_parallel_enabled: + x_sigma_max = cat_outputs_cp(x_sigma_max, seq_dim=2, cp_group=self.net.cp_group) + return samples, x_sigma_max / self.sde.sigma_max + + return samples + + def get_x0_fn_from_batch_with_condition_latent( + self, + data_batch: Dict, + guidance: float = 1.5, + is_negative_prompt: bool = False, + condition_latent: torch.Tensor = None, + num_condition_t: Union[int, None] = None, + condition_video_augment_sigma_in_inference: float = None, + add_input_frames_guidance: bool = False, + seed_inference: int = 1, + ) -> Callable: + """ + Generates a callable function `x0_fn` based on the provided data batch and guidance factor. + Different from the base model, this function support condition latent as input, it will add the condition information into the condition and uncondition object. + + This function first processes the input data batch through a conditioning workflow (`conditioner`) to obtain conditioned and unconditioned states. It then defines a nested function `x0_fn` which applies a denoising operation on an input `noise_x` at a given noise level `sigma` using both the conditioned and unconditioned states. + + Args: + - data_batch (Dict): A batch of data used for conditioning. The format and content of this dictionary should align with the expectations of the `self.conditioner` + - guidance (float, optional): A scalar value that modulates the influence of the conditioned state relative to the unconditioned state in the output. Defaults to 1.5. + - is_negative_prompt (bool): use negative prompt t5 in uncondition if true + - condition_latent (torch.Tensor): latent tensor in shape B,C,T,H,W as condition to generate video. + - num_condition_t (int): number of condition latent T, used in inference to decide the condition region and config.conditioner.video_cond_bool.condition_location == "first_n" + - condition_video_augment_sigma_in_inference (float): sigma for condition video augmentation in inference + - add_input_frames_guidance (bool): add guidance to the input frames, used for cfg on input frames + + Returns: + - Callable: A function `x0_fn(noise_x, sigma)` that takes two arguments, `noise_x` and `sigma`, and return x0 predictoin + + The returned function is suitable for use in scenarios where a denoised state is required based on both conditioned and unconditioned inputs, with an adjustable level of guidance influence. + """ + if is_negative_prompt: + condition, uncondition = self.conditioner.get_condition_with_negative_prompt(data_batch) + else: + condition, uncondition = self.conditioner.get_condition_uncondition(data_batch) + + condition.video_cond_bool = True + condition = self.add_condition_video_indicator_and_video_input_mask( + condition_latent, condition, num_condition_t + ) + if self.config.conditioner.video_cond_bool.add_pose_condition: + condition = self.add_condition_pose(data_batch, condition) + + uncondition.video_cond_bool = False if add_input_frames_guidance else True + uncondition = self.add_condition_video_indicator_and_video_input_mask( + condition_latent, uncondition, num_condition_t + ) + if self.config.conditioner.video_cond_bool.add_pose_condition: + uncondition = self.add_condition_pose(data_batch, uncondition) + + to_cp = self.net.is_context_parallel_enabled + # For inference, check if parallel_state is initialized + if parallel_state.is_initialized(): + condition = broadcast_condition(condition, to_tp=True, to_cp=to_cp) + uncondition = broadcast_condition(uncondition, to_tp=True, to_cp=to_cp) + else: + assert not to_cp, "parallel_state is not initialized, context parallel should be turned off." + + def x0_fn(noise_x: torch.Tensor, sigma: torch.Tensor) -> torch.Tensor: + cond_x0 = self.denoise( + noise_x, + sigma, + condition, + condition_video_augment_sigma_in_inference=condition_video_augment_sigma_in_inference, + seed_inference=seed_inference, + ).x0_pred_replaced + uncond_x0 = self.denoise( + noise_x, + sigma, + uncondition, + condition_video_augment_sigma_in_inference=condition_video_augment_sigma_in_inference, + seed_inference=seed_inference, + ).x0_pred_replaced + return cond_x0 + guidance * (cond_x0 - uncond_x0) + + return x0_fn + + def add_condition_video_indicator_and_video_input_mask( + self, latent_state: torch.Tensor, condition: VideoExtendCondition, num_condition_t: Union[int, None] = None + ) -> VideoExtendCondition: + """Add condition_video_indicator and condition_video_input_mask to the condition object for video conditioning. + condition_video_indicator is a binary tensor indicating the condition region in the latent state. 1x1xTx1x1 tensor. + condition_video_input_mask will be concat with the input for the network. + Args: + latent_state (torch.Tensor): latent state tensor in shape B,C,T,H,W + condition (VideoExtendCondition): condition object + num_condition_t (int): number of condition latent T, used in inference to decide the condition region and config.conditioner.video_cond_bool.condition_location == "first_n" + Returns: + VideoExtendCondition: updated condition object + """ + T = latent_state.shape[2] + latent_dtype = latent_state.dtype + condition_video_indicator = torch.zeros(1, 1, T, 1, 1, device=latent_state.device).type( + latent_dtype + ) # 1 for condition region + if self.config.conditioner.video_cond_bool.condition_location == "first_n": + # Only in inference to decide the condition region + assert num_condition_t is not None, "num_condition_t should be provided" + assert num_condition_t <= T, f"num_condition_t should be less than T, get {num_condition_t}, {T}" + log.info( + f"condition_location first_n, num_condition_t {num_condition_t}, condition.video_cond_bool {condition.video_cond_bool}" + ) + condition_video_indicator[:, :, :num_condition_t] += 1.0 + elif self.config.conditioner.video_cond_bool.condition_location == "first_random_n": + # Only in training + num_condition_t_max = self.config.conditioner.video_cond_bool.first_random_n_num_condition_t_max + assert ( + num_condition_t_max <= T + ), f"num_condition_t_max should be less than T, get {num_condition_t_max}, {T}" + assert num_condition_t_max >= self.config.conditioner.video_cond_bool.first_random_n_num_condition_t_min + num_condition_t = torch.randint( + self.config.conditioner.video_cond_bool.first_random_n_num_condition_t_min, + num_condition_t_max + 1, + (1,), + ).item() + condition_video_indicator[:, :, :num_condition_t] += 1.0 + + elif self.config.conditioner.video_cond_bool.condition_location == "random": + # Only in training + condition_rate = self.config.conditioner.video_cond_bool.random_conditon_rate + flag = torch.ones(1, 1, T, 1, 1, device=latent_state.device).type(latent_dtype) * condition_rate + condition_video_indicator = torch.bernoulli(flag).type(latent_dtype).to(latent_state.device) + else: + raise NotImplementedError( + f"condition_location {self.config.conditioner.video_cond_bool.condition_location} not implemented; training={self.training}" + ) + condition.gt_latent = latent_state + condition.condition_video_indicator = condition_video_indicator + + B, C, T, H, W = latent_state.shape + # Create additional input_mask channel, this will be concatenated to the input of the network + # See design doc section (Implementation detail A.1 and A.2) for visualization + ones_padding = torch.ones((B, 1, T, H, W), dtype=latent_state.dtype, device=latent_state.device) + zeros_padding = torch.zeros((B, 1, T, H, W), dtype=latent_state.dtype, device=latent_state.device) + assert condition.video_cond_bool is not None, "video_cond_bool should be set" + + # The input mask indicate whether the input is conditional region or not + if condition.video_cond_bool: # Condition one given video frames + condition.condition_video_input_mask = ( + condition_video_indicator * ones_padding + (1 - condition_video_indicator) * zeros_padding + ) + else: # Unconditional case, use for cfg + condition.condition_video_input_mask = zeros_padding + + to_cp = self.net.is_context_parallel_enabled + # For inference, check if parallel_state is initialized + if parallel_state.is_initialized(): + condition = broadcast_condition(condition, to_tp=True, to_cp=to_cp) + else: + assert not to_cp, "parallel_state is not initialized, context parallel should be turned off." + + return condition + + def add_condition_pose(self, data_batch: Dict, condition: VideoExtendCondition) -> VideoExtendCondition: + """Add pose condition to the condition object. For camera control model + Args: + data_batch (Dict): data batch, with key "plucker_embeddings", in shape B,T,C,H,W + latent_state (torch.Tensor): latent state tensor in shape B,C,T,H,W + condition (VideoExtendCondition): condition object + num_condition_t (int): number of condition latent T, used in inference to decide the condition region and config.conditioner.video_cond_bool.condition_location == "first_n" + Returns: + VideoExtendCondition: updated condition object + """ + assert ( + "plucker_embeddings" in data_batch or "plucker_embeddings_downsample" in data_batch.keys() + ), f"plucker_embeddings should be in data_batch. only find {data_batch.keys()}" + plucker_embeddings = ( + data_batch["plucker_embeddings"] + if "plucker_embeddings_downsample" not in data_batch.keys() + else data_batch["plucker_embeddings_downsample"] + ) + condition.condition_video_pose = rearrange(plucker_embeddings, "b t c h w -> b c t h w").contiguous() + to_cp = self.net.is_context_parallel_enabled + # For inference, check if parallel_state is initialized + if parallel_state.is_initialized(): + condition = broadcast_condition(condition, to_tp=True, to_cp=to_cp) + else: + assert not to_cp, "parallel_state is not initialized, context parallel should be turned off." + + return condition + + def sample_tokens_start_from_p_or_i(self, latent_state: torch.Tensor) -> torch.Tensor: + """Sample the PPP... from the IPPP... sequence, only for video sequence + Args: + latent_state (torch.Tensor): latent state tensor in shape B,C,T,H,W + Returns: + torch.Tensor: sampled PPP tensor in shape B,C,T,H,W + """ + B, C, T, H, W = latent_state.shape + latent_dtype = latent_state.dtype + T_target = self.state_shape[1] + latent_state_sample = torch.zeros((B, C, T_target, H, W), dtype=latent_dtype, device=latent_state.device) + t_start = torch.randint(0, T - T_target + 1, (1,)) + # broadcast to other device + latent_state_sample = latent_state[:, :, t_start : t_start + T_target].contiguous() + if parallel_state.is_initialized(): + latent_state_sample = _broadcast(latent_state_sample, to_tp=True, to_cp=True) + + return latent_state_sample + + +@diffusion_fsdp_class_decorator +class FSDPExtendDiffusionModel(ExtendDiffusionModel): + pass diff --git a/cosmos_transfer1/diffusion/training/models/extend_model_multiview.py b/cosmos_transfer1/diffusion/training/models/extend_model_multiview.py new file mode 100644 index 00000000..02f9a9a7 --- /dev/null +++ b/cosmos_transfer1/diffusion/training/models/extend_model_multiview.py @@ -0,0 +1,448 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Callable, Dict, Tuple, Union + +import torch +from einops import rearrange +from megatron.core import parallel_state +from torch import Tensor + +from cosmos_transfer1.diffusion.functional.batch_ops import batch_mul +from cosmos_transfer1.diffusion.conditioner import DataType, VideoExtendCondition +from cosmos_transfer1.diffusion.module.parallel import cat_outputs_cp, split_inputs_cp +from cosmos_transfer1.diffusion.training.models.extend_model import ( + ExtendDiffusionModel, + VideoDenoisePrediction, + normalize_condition_latent, +) +from cosmos_transfer1.diffusion.training.models.model import DiffusionModel, broadcast_condition +from cosmos_transfer1.diffusion.training.models.model_image import CosmosCondition, diffusion_fsdp_class_decorator +from cosmos_transfer1.utils import log + +from cosmos_transfer1.diffusion.config.base.conditioner import VideoCondBoolConfig + + +class MultiviewExtendDiffusionModel(ExtendDiffusionModel): + def __init__(self, config): + super().__init__(config) + self.n_views = config.n_views + + @torch.no_grad() + def encode(self, state: torch.Tensor) -> torch.Tensor: + state = rearrange(state, "B C (V T) H W -> (B V) C T H W", V=self.n_views) + encoded_state = self.vae.encode(state) + encoded_state = rearrange(encoded_state, "(B V) C T H W -> B C (V T) H W", V=self.n_views) * self.sigma_data + return encoded_state + + @torch.no_grad() + def decode(self, latent: torch.Tensor) -> torch.Tensor: + latent = rearrange(latent, "B C (V T) H W -> (B V) C T H W", V=self.n_views) + decoded_state = self.vae.decode(latent / self.sigma_data) + decoded_state = rearrange(decoded_state, "(B V) C T H W -> B C (V T) H W", V=self.n_views) + return decoded_state + + def compute_loss_with_epsilon_and_sigma( + self, + data_batch: dict[str, torch.Tensor], + x0_from_data_batch: torch.Tensor, + x0: torch.Tensor, + condition: CosmosCondition, + epsilon: torch.Tensor, + sigma: torch.Tensor, + ): + if self.is_image_batch(data_batch): + # Turn off CP + self.net.disable_context_parallel() + else: + if parallel_state.is_initialized(): + if parallel_state.get_context_parallel_world_size() > 1: + # Turn on CP + cp_group = parallel_state.get_context_parallel_group() + self.net.enable_context_parallel(cp_group) + log.debug("[CP] Split x0 and epsilon") + + x0 = rearrange(x0, "B C (V T) H W -> (B V) C T H W", V=self.n_views) + epsilon = rearrange(epsilon, "B C (V T) H W -> (B V) C T H W", V=self.n_views) + + x0 = split_inputs_cp(x=x0, seq_dim=2, cp_group=self.net.cp_group) + epsilon = split_inputs_cp(x=epsilon, seq_dim=2, cp_group=self.net.cp_group) + + x0 = rearrange(x0, "(B V) C T H W -> B C (V T) H W", V=self.n_views) + epsilon = rearrange(epsilon, "(B V) C T H W -> B C (V T) H W", V=self.n_views) + output_batch, kendall_loss, pred_mse, edm_loss = super( + DiffusionModel, self + ).compute_loss_with_epsilon_and_sigma(data_batch, x0_from_data_batch, x0, condition, epsilon, sigma) + if not self.is_image_batch(data_batch): + if self.loss_reduce == "sum" and parallel_state.get_context_parallel_world_size() > 1: + kendall_loss *= parallel_state.get_context_parallel_world_size() + + return output_batch, kendall_loss, pred_mse, edm_loss + + def denoise( + self, + noise_x: Tensor, + sigma: Tensor, + condition: VideoExtendCondition, + condition_video_augment_sigma_in_inference: float = 0.001, + ) -> VideoDenoisePrediction: + """ + Denoise the noisy input tensor. + + Args: + noise_x (Tensor): Noisy input tensor. + sigma (Tensor): Noise level. + condition (VideoExtendCondition): Condition for denoising. + condition_video_augment_sigma_in_inference (float): sigma for condition video augmentation in inference + + Returns: + Tensor: Denoised output tensor. + """ + if condition.data_type == DataType.IMAGE: + pred = super(DiffusionModel, self).denoise(noise_x, sigma, condition) + log.debug(f"hit image denoise, noise_x shape {noise_x.shape}, sigma shape {sigma.shape}", rank0_only=False) + return VideoDenoisePrediction( + x0=pred.x0, + eps=pred.eps, + logvar=pred.logvar, + xt=noise_x, + ) + else: + assert ( + condition.gt_latent is not None + ), f"find None gt_latent in condition, likely didn't call self.add_condition_video_indicator_and_video_input_mask when preparing the condition or this is a image batch but condition.data_type is wrong, get {noise_x.shape}" + gt_latent = condition.gt_latent + cfg_video_cond_bool: VideoCondBoolConfig = self.config.conditioner.video_cond_bool + + condition_latent = gt_latent + + if cfg_video_cond_bool.normalize_condition_latent: + condition_latent = normalize_condition_latent(condition_latent) + + # Augment the latent with different sigma value, and add the augment_sigma to the condition object if needed + condition, augment_latent = self.augment_conditional_latent_frames( + condition, cfg_video_cond_bool, condition_latent, condition_video_augment_sigma_in_inference, sigma + ) + condition_video_indicator = condition.condition_video_indicator # [B, 1, T, 1, 1] + if parallel_state.get_context_parallel_world_size() > 1: + cp_group = parallel_state.get_context_parallel_group() + + condition_video_indicator = rearrange( + condition_video_indicator, "B C (V T) H W -> (B V) C T H W", V=self.n_views + ) + augment_latent = rearrange(augment_latent, "B C (V T) H W -> (B V) C T H W", V=self.n_views) + gt_latent = rearrange(gt_latent, "B C (V T) H W -> (B V) C T H W", V=self.n_views) + + condition_video_indicator = split_inputs_cp(condition_video_indicator, seq_dim=2, cp_group=cp_group) + augment_latent = split_inputs_cp(augment_latent, seq_dim=2, cp_group=cp_group) + gt_latent = split_inputs_cp(gt_latent, seq_dim=2, cp_group=cp_group) + + condition_video_indicator = rearrange( + condition_video_indicator, "(B V) C T H W -> B C (V T) H W", V=self.n_views + ) + augment_latent = rearrange(augment_latent, "(B V) C T H W -> B C (V T) H W", V=self.n_views) + gt_latent = rearrange(gt_latent, "(B V) C T H W -> B C (V T) H W", V=self.n_views) + + if not condition.video_cond_bool: + # Unconditional case, drop out the condition region + augment_latent = self.drop_out_condition_region(augment_latent, noise_x, cfg_video_cond_bool) + # Compose the model input with condition region (augment_latent) and generation region (noise_x) + new_noise_xt = condition_video_indicator * augment_latent + (1 - condition_video_indicator) * noise_x + # Call the abse model + + denoise_pred = super(DiffusionModel, self).denoise(new_noise_xt, sigma, condition) + + x0_pred_replaced = condition_video_indicator * gt_latent + (1 - condition_video_indicator) * denoise_pred.x0 + if cfg_video_cond_bool.compute_loss_for_condition_region: + # We also denoise the conditional region + x0_pred = denoise_pred.x0 + else: + x0_pred = x0_pred_replaced + + return VideoDenoisePrediction( + x0=x0_pred, + eps=batch_mul(noise_x - x0_pred, 1.0 / sigma), + logvar=denoise_pred.logvar, + net_in=batch_mul(1.0 / torch.sqrt(self.sigma_data**2 + sigma**2), new_noise_xt), + net_x0_pred=denoise_pred.x0, + xt=new_noise_xt, + x0_pred_replaced=x0_pred_replaced, + ) + + def add_condition_video_indicator_and_video_input_mask( + self, latent_state: torch.Tensor, condition: VideoExtendCondition, num_condition_t: Union[int, None] = None + ) -> VideoExtendCondition: + """Add condition_video_indicator and condition_video_input_mask to the condition object for video conditioning. + condition_video_indicator is a binary tensor indicating the condition region in the latent state. 1x1xTx1x1 tensor. + condition_video_input_mask will be concat with the input for the network. + Args: + latent_state (torch.Tensor): latent state tensor in shape B,C,T,H,W + condition (VideoExtendCondition): condition object + num_condition_t (int): number of condition latent T, used in inference to decide the condition region and config.conditioner.video_cond_bool.condition_location == "first_n" + Returns: + VideoExtendCondition: updated condition object + """ + T = latent_state.shape[2] + latent_dtype = latent_state.dtype + condition_video_indicator = torch.zeros(1, 1, T, 1, 1, device=latent_state.device).type( + latent_dtype + ) # 1 for condition region + + condition_video_indicator = rearrange( + condition_video_indicator, "B C (V T) H W -> (B V) C T H W", V=self.n_views + ) + if self.config.conditioner.video_cond_bool.condition_location == "first_n": + # Only in inference to decide the condition region + assert num_condition_t is not None, "num_condition_t should be provided" + assert num_condition_t <= T, f"num_condition_t should be less than T, get {num_condition_t}, {T}" + log.info( + f"condition_location first_n, num_condition_t {num_condition_t}, condition.video_cond_bool {condition.video_cond_bool}" + ) + + condition_video_indicator[:, :, :num_condition_t] += 1.0 + + elif self.config.conditioner.video_cond_bool.condition_location == "first_random_n": + # Only in training + num_condition_t_max = self.config.conditioner.video_cond_bool.first_random_n_num_condition_t_max + assert ( + num_condition_t_max <= T + ), f"num_condition_t_max should be less than T, get {num_condition_t_max}, {T}" + num_condition_t = torch.randint(0, num_condition_t_max + 1, (1,)).item() + condition_video_indicator[:, :, :num_condition_t] += 1.0 + + else: + raise NotImplementedError( + f"condition_location {self.config.conditioner.video_cond_bool.condition_location} not implemented; training={self.training}" + ) + + condition_video_indicator = rearrange( + condition_video_indicator, "(B V) C T H W -> B C (V T) H W", V=self.n_views + ) + + condition.gt_latent = latent_state + condition.condition_video_indicator = condition_video_indicator + + B, C, T, H, W = latent_state.shape + # Create additional input_mask channel, this will be concatenated to the input of the network + # See design doc section (Implementation detail A.1 and A.2) for visualization + ones_padding = torch.ones((B, 1, T, H, W), dtype=latent_state.dtype, device=latent_state.device) + zeros_padding = torch.zeros((B, 1, T, H, W), dtype=latent_state.dtype, device=latent_state.device) + assert condition.video_cond_bool is not None, "video_cond_bool should be set" + + # The input mask indicate whether the input is conditional region or not + if condition.video_cond_bool: # Condition one given video frames + condition.condition_video_input_mask = ( + condition_video_indicator * ones_padding + (1 - condition_video_indicator) * zeros_padding + ) + else: # Unconditional case, use for cfg + condition.condition_video_input_mask = zeros_padding + + to_cp = self.net.is_context_parallel_enabled + # For inference, check if parallel_state is initialized + if parallel_state.is_initialized(): + condition = broadcast_condition(condition, to_tp=True, to_cp=to_cp) + else: + assert not to_cp, "parallel_state is not initialized, context parallel should be turned off." + + return condition + + def get_x0_fn_from_batch_with_condition_latent( + self, + data_batch: Dict, + guidance: float = 1.5, + is_negative_prompt: bool = False, + condition_latent: torch.Tensor = None, + num_condition_t: Union[int, None] = None, + condition_video_augment_sigma_in_inference: float = None, + add_input_frames_guidance: bool = False, + guidance_other: Union[float, None] = None, + ) -> Callable: + """ + Generates a callable function `x0_fn` based on the provided data batch and guidance factor. + Different from the base model, this function support condition latent as input, it will add the condition information into the condition and uncondition object. + + This function first processes the input data batch through a conditioning workflow (`conditioner`) to obtain conditioned and unconditioned states. It then defines a nested function `x0_fn` which applies a denoising operation on an input `noise_x` at a given noise level `sigma` using both the conditioned and unconditioned states. + + Args: + - data_batch (Dict): A batch of data used for conditioning. The format and content of this dictionary should align with the expectations of the `self.conditioner` + - guidance (float, optional): A scalar value that modulates the influence of the conditioned state relative to the unconditioned state in the output. Defaults to 1.5. + - is_negative_prompt (bool): use negative prompt t5 in uncondition if true + - condition_latent (torch.Tensor): latent tensor in shape B,C,T,H,W as condition to generate video. + - num_condition_t (int): number of condition latent T, used in inference to decide the condition region and config.conditioner.video_cond_bool.condition_location == "first_n" + - condition_video_augment_sigma_in_inference (float): sigma for condition video augmentation in inference + - add_input_frames_guidance (bool): add guidance to the input frames, used for cfg on input frames + + Returns: + - Callable: A function `x0_fn(noise_x, sigma)` that takes two arguments, `noise_x` and `sigma`, and return x0 predictoin + + The returned function is suitable for use in scenarios where a denoised state is required based on both conditioned and unconditioned inputs, with an adjustable level of guidance influence. + """ + if is_negative_prompt: + condition, uncondition = self.conditioner.get_condition_with_negative_prompt(data_batch) + else: + condition, uncondition = self.conditioner.get_condition_uncondition(data_batch) + + condition.video_cond_bool = True + condition = self.add_condition_video_indicator_and_video_input_mask( + condition_latent, condition, num_condition_t + ) + if self.config.conditioner.video_cond_bool.add_pose_condition: + condition = self.add_condition_pose(data_batch, condition) + + uncondition.video_cond_bool = False if add_input_frames_guidance else True + uncondition = self.add_condition_video_indicator_and_video_input_mask( + condition_latent, uncondition, num_condition_t + ) + if self.config.conditioner.video_cond_bool.add_pose_condition: + uncondition = self.add_condition_pose(data_batch, uncondition) + + to_cp = self.net.is_context_parallel_enabled + # For inference, check if parallel_state is initialized + if parallel_state.is_initialized(): + condition = broadcast_condition(condition, to_tp=True, to_cp=to_cp) + uncondition = broadcast_condition(uncondition, to_tp=True, to_cp=to_cp) + else: + assert not to_cp, "parallel_state is not initialized, context parallel should be turned off." + if guidance_other is not None: # and guidance_other != guidance: + import copy + + assert not parallel_state.is_initialized(), "Parallel state not supported with two guidances." + condition_other = copy.deepcopy(uncondition) + condition_other.trajectory = condition.trajectory + + def x0_fn(noise_x: torch.Tensor, sigma: torch.Tensor) -> torch.Tensor: + cond_x0 = self.denoise( + noise_x, + sigma, + condition, + condition_video_augment_sigma_in_inference=condition_video_augment_sigma_in_inference, + ).x0_pred_replaced + uncond_x0 = self.denoise( + noise_x, + sigma, + uncondition, + condition_video_augment_sigma_in_inference=condition_video_augment_sigma_in_inference, + ).x0_pred_replaced + cond_other_x0 = self.denoise( + noise_x, + sigma, + condition_other, + condition_video_augment_sigma_in_inference=condition_video_augment_sigma_in_inference, + ).x0_pred_replaced + return cond_x0 + guidance * (cond_x0 - uncond_x0) + guidance_other * (cond_other_x0 - uncond_x0) + + else: + + def x0_fn(noise_x: torch.Tensor, sigma: torch.Tensor) -> torch.Tensor: + cond_x0 = self.denoise( + noise_x, + sigma, + condition, + condition_video_augment_sigma_in_inference=condition_video_augment_sigma_in_inference, + ).x0_pred_replaced + uncond_x0 = self.denoise( + noise_x, + sigma, + uncondition, + condition_video_augment_sigma_in_inference=condition_video_augment_sigma_in_inference, + ).x0_pred_replaced + return cond_x0 + guidance * (cond_x0 - uncond_x0) + + return x0_fn + + def generate_samples_from_batch( + self, + data_batch: Dict, + guidance: float = 1.5, + seed: int = 1, + state_shape: Tuple | None = None, + n_sample: int | None = None, + is_negative_prompt: bool = False, + num_steps: int = 35, + condition_latent: Union[torch.Tensor, None] = None, + num_condition_t: Union[int, None] = None, + condition_video_augment_sigma_in_inference: float = None, + add_input_frames_guidance: bool = False, + guidance_other: Union[float, None] = None, + ) -> Tensor: + """ + Generate samples from the batch. Based on given batch, it will automatically determine whether to generate image or video samples. + Different from the base model, this function support condition latent as input, it will create a differnt x0_fn if condition latent is given. + If this feature is stablized, we could consider to move this function to the base model. + + Args: + condition_latent (Optional[torch.Tensor]): latent tensor in shape B,C,T,H,W as condition to generate video. + num_condition_t (Optional[int]): number of condition latent T, if None, will use the whole first half + + add_input_frames_guidance (bool): add guidance to the input frames, used for cfg on input frames + """ + self._normalize_video_databatch_inplace(data_batch) + self._augment_image_dim_inplace(data_batch) + is_image_batch = self.is_image_batch(data_batch) + if is_image_batch: + log.debug("image batch, call base model generate_samples_from_batch") + return super().generate_samples_from_batch( + data_batch, + guidance=guidance, + seed=seed, + state_shape=state_shape, + n_sample=n_sample, + is_negative_prompt=is_negative_prompt, + num_steps=num_steps, + ) + if n_sample is None: + input_key = self.input_image_key if is_image_batch else self.input_data_key + n_sample = data_batch[input_key].shape[0] + if state_shape is None: + if is_image_batch: + state_shape = (self.state_shape[0], 1, *self.state_shape[2:]) # C,T,H,W + else: + log.debug(f"Default Video state shape is used. {self.state_shape}") + state_shape = self.state_shape + + assert condition_latent is not None, "condition_latent should be provided" + x0_fn = self.get_x0_fn_from_batch_with_condition_latent( + data_batch, + guidance, + is_negative_prompt=is_negative_prompt, + condition_latent=condition_latent, + num_condition_t=num_condition_t, + condition_video_augment_sigma_in_inference=condition_video_augment_sigma_in_inference, + add_input_frames_guidance=add_input_frames_guidance, + guidance_other=guidance_other, + ) + + generator = torch.Generator(device=self.tensor_kwargs["device"]) + generator.manual_seed(seed) + x_sigma_max = ( + torch.randn(n_sample, *state_shape, **self.tensor_kwargs, generator=generator) * self.sde.sigma_max + ) + + if self.net.is_context_parallel_enabled: + x_sigma_max = rearrange(x_sigma_max, "B C (V T) H W -> (B V) C T H W", V=self.n_views) + x_sigma_max = split_inputs_cp(x=x_sigma_max, seq_dim=2, cp_group=self.net.cp_group) + x_sigma_max = rearrange(x_sigma_max, "(B V) C T H W -> B C (V T) H W", V=self.n_views) + + samples = self.sampler(x0_fn, x_sigma_max, num_steps=num_steps, sigma_max=self.sde.sigma_max) + if self.net.is_context_parallel_enabled: + samples = rearrange(samples, "B C (V T) H W -> (B V) C T H W", V=self.n_views) + samples = cat_outputs_cp(samples, seq_dim=2, cp_group=self.net.cp_group) + samples = rearrange(samples, "(B V) C T H W -> B C (V T) H W", V=self.n_views) + return samples + + +@diffusion_fsdp_class_decorator +class FSDPExtendDiffusionModel(MultiviewExtendDiffusionModel): + pass diff --git a/cosmos_transfer1/diffusion/training/models/model.py b/cosmos_transfer1/diffusion/training/models/model.py new file mode 100644 index 00000000..e6aef604 --- /dev/null +++ b/cosmos_transfer1/diffusion/training/models/model.py @@ -0,0 +1,660 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from typing import Any, Callable, Dict, List, Mapping, Optional, Tuple, Union + +import amp_C +import torch +from apex.multi_tensor_apply import multi_tensor_applier +from einops import rearrange +from megatron.core import parallel_state +from torch import Tensor +from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors +from torch.distributed import broadcast_object_list, get_process_group_ranks +from torch.distributed.utils import _verify_param_shape_across_processes + +from cosmos_transfer1.diffusion.diffusion.modules.res_sampler import COMMON_SOLVER_OPTIONS +from cosmos_transfer1.diffusion.conditioner import BaseVideoCondition, DataType +from cosmos_transfer1.diffusion.module.parallel import cat_outputs_cp, split_inputs_cp +from cosmos_transfer1.diffusion.training.models.model_image import CosmosCondition +from cosmos_transfer1.diffusion.training.models.model_image import DiffusionModel as ImageModel +from cosmos_transfer1.diffusion.training.models.model_image import diffusion_fsdp_class_decorator +from cosmos_transfer1.utils import distributed, log, misc + +l2_norm_impl = amp_C.multi_tensor_l2norm +multi_tensor_scale_impl = amp_C.multi_tensor_scale + +# key to check if the video data is normalized or image data is converted to video data +# to avoid apply normalization or augment image dimension multiple times +# It is due to we do not have normalization and augment image dimension in the dataloader and move it to the model +IS_PREPROCESSED_KEY = "is_preprocessed" + + +def robust_broadcast(tensor: torch.Tensor, src: int, pg, is_check_shape: bool = False) -> torch.Tensor: + """ + Perform a robust broadcast operation that works regardless of tensor shapes on different ranks. + + Args: + tensor (torch.Tensor): The tensor to broadcast (on src rank) or receive (on other ranks). + src (int): The source rank for the broadcast. Defaults to 0. + + Returns: + torch.Tensor: The broadcasted tensor on all ranks. + """ + # First, broadcast the shape of the tensor + if distributed.get_rank() == src: + shape = torch.tensor(tensor.shape).cuda() + else: + shape = torch.empty(tensor.dim(), dtype=torch.long).cuda() + if is_check_shape: + _verify_param_shape_across_processes(pg, [shape]) + torch.distributed.broadcast(shape, src, group=pg) + + # Resize the tensor on non-src ranks if necessary + if distributed.get_rank() != src: + tensor = tensor.new_empty(shape.tolist()).type_as(tensor) + + # Now broadcast the tensor data + torch.distributed.broadcast(tensor, src, group=pg) + + return tensor + + +def _broadcast(item: torch.Tensor | str | None, to_tp: bool = True, to_cp: bool = True) -> torch.Tensor | str | None: + """ + Broadcast the item from the minimum rank in the specified group(s). + Since global rank = tp_rank + cp_rank * tp_size + ... + First broadcast in the tp_group and then in the cp_group will + ensure that the item is broadcasted across ranks in cp_group and tp_group. + + Parameters: + - item: The item to broadcast (can be a torch.Tensor, str, or None). + - to_tp: Whether to broadcast to the tensor model parallel group. + - to_cp: Whether to broadcast to the context parallel group. + """ + if not parallel_state.is_initialized(): + return item + tp_group = parallel_state.get_tensor_model_parallel_group() + cp_group = parallel_state.get_context_parallel_group() + + to_tp = to_tp and parallel_state.get_tensor_model_parallel_world_size() > 1 + to_cp = to_cp and parallel_state.get_context_parallel_world_size() > 1 + + if to_tp: + min_tp_rank = min(get_process_group_ranks(tp_group)) + + if to_cp: + min_cp_rank = min(get_process_group_ranks(cp_group)) + + if isinstance(item, torch.Tensor): # assume the device is cuda + # log.info(f"{item.shape}", rank0_only=False) + if to_tp: + # torch.distributed.broadcast(item, min_tp_rank, group=tp_group) + item = robust_broadcast(item, min_tp_rank, tp_group) + if to_cp: + # torch.distributed.broadcast(item, min_cp_rank, group=cp_group) + item = robust_broadcast(item, min_cp_rank, cp_group) + elif item is not None: + broadcastable_list = [item] + if to_tp: + # log.info(f"{broadcastable_list}", rank0_only=False) + broadcast_object_list(broadcastable_list, min_tp_rank, group=tp_group) + if to_cp: + broadcast_object_list(broadcastable_list, min_cp_rank, group=cp_group) + + item = broadcastable_list[0] + return item + + +def broadcast_condition(condition: BaseVideoCondition, to_tp: bool = True, to_cp: bool = True) -> BaseVideoCondition: + condition_kwargs = {} + for k, v in condition.to_dict().items(): + if isinstance(v, torch.Tensor): + assert not v.requires_grad, f"{k} requires gradient. the current impl does not support it" + condition_kwargs[k] = _broadcast(v, to_tp=to_tp, to_cp=to_cp) + condition = type(condition)(**condition_kwargs) + return condition + + +class DiffusionModel(ImageModel): + def __init__(self, config): + super().__init__(config) + # Initialize trained_data_record with defaultdict, key: image, video, iteration + self.trained_data_record = { + "image": 0, + "video": 0, + "iteration": 0, + } + if parallel_state.is_initialized(): + self.data_parallel_size = parallel_state.get_data_parallel_world_size() + else: + self.data_parallel_size = 1 + + if self.config.adjust_video_noise: + self.video_noise_multiplier = math.sqrt(self.state_shape[1]) + else: + self.video_noise_multiplier = 1.0 + + def setup_data_key(self) -> None: + self.input_data_key = self.config.input_data_key # by default it is video key for Video diffusion model + self.input_image_key = self.config.input_image_key + + def is_image_batch(self, data_batch: dict[str, Tensor]) -> bool: + """We hanlde two types of data_batch. One comes from a joint_dataloader where "dataset_name" can be used to differenciate image_batch and video_batch. + Another comes from a dataloader which we by default assumes as video_data for video model training. + """ + is_image = self.input_image_key in data_batch + is_video = self.input_data_key in data_batch + assert ( + is_image != is_video + ), "Only one of the input_image_key or input_data_key should be present in the data_batch." + return is_image + + def draw_training_sigma_and_epsilon(self, size: int, condition: BaseVideoCondition) -> Tensor: + sigma_B, epsilon = super().draw_training_sigma_and_epsilon(size, condition) + is_video_batch = condition.data_type == DataType.VIDEO + multiplier = self.video_noise_multiplier if is_video_batch else 1 + sigma_B = _broadcast(sigma_B * multiplier, to_tp=True, to_cp=is_video_batch) + epsilon = _broadcast(epsilon, to_tp=True, to_cp=is_video_batch) + return sigma_B, epsilon + + @torch.no_grad() + def validation_step( + self, data: dict[str, torch.Tensor], iteration: int + ) -> tuple[dict[str, torch.Tensor], torch.Tensor]: + """ + save generated videos + """ + raw_data, x0, condition = self.get_data_and_condition(data) + guidance = data["guidance"] + data = misc.to(data, **self.tensor_kwargs) + sample = self.generate_samples_from_batch( + data, + guidance=guidance, + # make sure no mismatch and also works for cp + state_shape=x0.shape[1:], + n_sample=x0.shape[0], + ) + sample = self.decode(sample) + gt = raw_data + caption = data["ai_caption"] + return {"gt": gt, "result": sample, "caption": caption}, torch.tensor([0]).to(**self.tensor_kwargs) + + def training_step(self, data_batch: Dict[str, Tensor], iteration: int) -> Tuple[Dict[str, Tensor] | Tensor]: + input_key = self.input_data_key # by default it is video key + if self.is_image_batch(data_batch): + input_key = self.input_image_key + batch_size = data_batch[input_key].shape[0] + self.trained_data_record["image" if self.is_image_batch(data_batch) else "video"] += ( + batch_size * self.data_parallel_size + ) + self.trained_data_record["iteration"] += 1 + return super().training_step(data_batch, iteration) + + def state_dict(self) -> Dict[str, Any]: + state_dict = super().state_dict() + state_dict["trained_data_record"] = self.trained_data_record + return state_dict + + def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True, assign: bool = False): + if "trained_data_record" in state_dict and hasattr(self, "trained_data_record"): + trained_data_record = state_dict.pop("trained_data_record") + if trained_data_record: + assert set(trained_data_record.keys()) == set(self.trained_data_record.keys()) + for k, v in trained_data_record.items(): + self.trained_data_record[k] = v + else: + log.warning("trained_data_record not found in the state_dict.") + return super().load_state_dict(state_dict, strict, assign) + + def _normalize_video_databatch_inplace(self, data_batch: dict[str, Tensor], input_key: str = None) -> None: + """ + Normalizes video data in-place on a CUDA device to reduce data loading overhead. + + This function modifies the video data tensor within the provided data_batch dictionary + in-place, scaling the uint8 data from the range [0, 255] to the normalized range [-1, 1]. + + Warning: + A warning is issued if the data has not been previously normalized. + + Args: + data_batch (dict[str, Tensor]): A dictionary containing the video data under a specific key. + This tensor is expected to be on a CUDA device and have dtype of torch.uint8. + + Side Effects: + Modifies the 'input_data_key' tensor within the 'data_batch' dictionary in-place. + + Note: + This operation is performed directly on the CUDA device to avoid the overhead associated + with moving data to/from the GPU. Ensure that the tensor is already on the appropriate device + and has the correct dtype (torch.uint8) to avoid unexpected behaviors. + """ + input_key = self.input_data_key if input_key is None else input_key + # only handle video batch + if input_key in data_batch: + # Check if the data has already been normalized and avoid re-normalizing + if IS_PREPROCESSED_KEY in data_batch and data_batch[IS_PREPROCESSED_KEY] is True: + assert torch.is_floating_point(data_batch[input_key]), "Video data is not in float format." + assert torch.all( + (data_batch[input_key] >= -1.0001) & (data_batch[input_key] <= 1.0001) + ), f"Video data is not in the range [-1, 1]. get data range [{data_batch[input_key].min()}, {data_batch[input_key].max()}]" + else: + assert data_batch[input_key].dtype == torch.uint8, "Video data is not in uint8 format." + data_batch[input_key] = data_batch[input_key].to(**self.tensor_kwargs) / 127.5 - 1.0 + data_batch[IS_PREPROCESSED_KEY] = True + + def _augment_image_dim_inplace(self, data_batch: dict[str, Tensor], input_key: str = None) -> None: + input_key = self.input_image_key if input_key is None else input_key + if input_key in data_batch: + # Check if the data has already been augmented and avoid re-augmenting + if IS_PREPROCESSED_KEY in data_batch and data_batch[IS_PREPROCESSED_KEY] is True: + assert ( + data_batch[input_key].shape[2] == 1 + ), f"Image data is claimed be augmented while its shape is {data_batch[input_key].shape}" + return + else: + data_batch[input_key] = rearrange(data_batch[input_key], "b c h w -> b c 1 h w").contiguous() + data_batch[IS_PREPROCESSED_KEY] = True + + def get_data_and_condition(self, data_batch: dict[str, Tensor]) -> Tuple[Tensor, BaseVideoCondition]: + self._normalize_video_databatch_inplace(data_batch) + self._augment_image_dim_inplace(data_batch) + input_key = self.input_data_key # by default it is video key + is_image_batch = self.is_image_batch(data_batch) + is_video_batch = not is_image_batch + + # Broadcast data and condition across TP and CP groups. + # sort keys to make sure the order is same, IMPORTANT! otherwise, nccl will hang! + local_keys = sorted(list(data_batch.keys())) + # log.critical(f"all keys {local_keys}", rank0_only=False) + for key in local_keys: + data_batch[key] = _broadcast(data_batch[key], to_tp=True, to_cp=is_video_batch) + + if is_image_batch: + input_key = self.input_image_key + + # Latent state + raw_state = data_batch[input_key] + latent_state = self.encode(raw_state).contiguous() + + # Condition + condition = self.conditioner(data_batch) + if is_image_batch: + condition.data_type = DataType.IMAGE + else: + condition.data_type = DataType.VIDEO + + # VAE has randomness. CP/TP group should have the same encoded output. + + latent_state = _broadcast(latent_state, to_tp=True, to_cp=is_video_batch) + condition = broadcast_condition(condition, to_tp=True, to_cp=is_video_batch) + + return raw_state, latent_state, condition + + def on_train_start(self, memory_format: torch.memory_format = torch.preserve_format) -> None: + super().on_train_start(memory_format) + if parallel_state.is_initialized() and parallel_state.get_tensor_model_parallel_world_size() > 1: + sequence_parallel = getattr(parallel_state, "sequence_parallel", False) + if sequence_parallel: + self.net.enable_sequence_parallel() + + def compute_loss_with_epsilon_and_sigma( + self, + data_batch: dict[str, torch.Tensor], + x0_from_data_batch: torch.Tensor, + x0: torch.Tensor, + condition: CosmosCondition, + epsilon: torch.Tensor, + sigma: torch.Tensor, + ): + if self.is_image_batch(data_batch): + # Turn off CP + self.net.disable_context_parallel() + else: + if parallel_state.is_initialized(): + if parallel_state.get_context_parallel_world_size() > 1: + # Turn on CP + cp_group = parallel_state.get_context_parallel_group() + self.net.enable_context_parallel(cp_group) + log.debug("[CP] Split x0 and epsilon") + x0 = split_inputs_cp(x=x0, seq_dim=2, cp_group=self.net.cp_group) + epsilon = split_inputs_cp(x=epsilon, seq_dim=2, cp_group=self.net.cp_group) + + output_batch, kendall_loss, pred_mse, edm_loss = super().compute_loss_with_epsilon_and_sigma( + data_batch, x0_from_data_batch, x0, condition, epsilon, sigma + ) + if not self.is_image_batch(data_batch): + if self.loss_reduce == "sum" and parallel_state.get_context_parallel_world_size() > 1: + kendall_loss *= parallel_state.get_context_parallel_world_size() + + return output_batch, kendall_loss, pred_mse, edm_loss + + def get_x0_fn_from_batch( + self, + data_batch: Dict, + guidance: float = 1.5, + is_negative_prompt: bool = False, + ) -> Callable: + """ + Generates a callable function `x0_fn` based on the provided data batch and guidance factor. + + This function first processes the input data batch through a conditioning workflow (`conditioner`) to obtain conditioned and unconditioned states. It then defines a nested function `x0_fn` which applies a denoising operation on an input `noise_x` at a given noise level `sigma` using both the conditioned and unconditioned states. + + Args: + - data_batch (Dict): A batch of data used for conditioning. The format and content of this dictionary should align with the expectations of the `self.conditioner` + - guidance (float, optional): A scalar value that modulates the influence of the conditioned state relative to the unconditioned state in the output. Defaults to 1.5. + - is_negative_prompt (bool): use negative prompt t5 in uncondition if true + + Returns: + - Callable: A function `x0_fn(noise_x, sigma)` that takes two arguments, `noise_x` and `sigma`, and return x0 predictoin + + The returned function is suitable for use in scenarios where a denoised state is required based on both conditioned and unconditioned inputs, with an adjustable level of guidance influence. + """ + if is_negative_prompt: + condition, uncondition = self.conditioner.get_condition_with_negative_prompt(data_batch) + else: + condition, uncondition = self.conditioner.get_condition_uncondition(data_batch) + + to_cp = self.net.is_context_parallel_enabled + # For inference, check if parallel_state is initialized + if parallel_state.is_initialized(): + condition = broadcast_condition(condition, to_tp=True, to_cp=to_cp) + uncondition = broadcast_condition(uncondition, to_tp=True, to_cp=to_cp) + else: + assert not to_cp, "parallel_state is not initialized, context parallel should be turned off." + + def x0_fn(noise_x: torch.Tensor, sigma: torch.Tensor) -> torch.Tensor: + cond_x0 = self.denoise(noise_x, sigma, condition).x0 + uncond_x0 = self.denoise(noise_x, sigma, uncondition).x0 + raw_x0 = cond_x0 + guidance * (cond_x0 - uncond_x0) + if "guided_image" in data_batch: + # replacement trick that enables inpainting with base model + assert "guided_mask" in data_batch, "guided_mask should be in data_batch if guided_image is present" + guide_image = data_batch["guided_image"] + guide_mask = data_batch["guided_mask"] + raw_x0 = guide_mask * guide_image + (1 - guide_mask) * raw_x0 + return raw_x0 + + return x0_fn + + def get_x_from_clean( + self, + in_clean_img: torch.Tensor, + sigma_max: float | None, + seed: int = 1, + ) -> Tensor: + """ + in_clean_img (torch.Tensor): input clean image for image-to-image/video-to-video by adding noise then denoising + sigma_max (float): maximum sigma applied to in_clean_image for image-to-image/video-to-video + """ + if in_clean_img is None: + return None + generator = torch.Generator(device=self.tensor_kwargs["device"]) + generator.manual_seed(seed) + noise = torch.randn(*in_clean_img.shape, **self.tensor_kwargs, generator=generator) + if sigma_max is None: + sigma_max = self.sde.sigma_max + x_sigma_max = in_clean_img + noise * sigma_max + return x_sigma_max + + def generate_samples_from_batch( + self, + data_batch: Dict, + guidance: float = 1.5, + seed: int = 1, + state_shape: Tuple | None = None, + n_sample: int | None = None, + is_negative_prompt: bool = False, + num_steps: int = 35, + solver_option: COMMON_SOLVER_OPTIONS = "2ab", + x_sigma_max: Optional[torch.Tensor] = None, + sigma_max: float | None = None, + return_noise: bool = False, + ) -> Tensor | Tuple[Tensor, Tensor]: + """ + Generate samples from the batch. Based on given batch, it will automatically determine whether to generate image or video samples. + Args: + data_batch (dict): raw data batch draw from the training data loader. + iteration (int): Current iteration number. + guidance (float): guidance weights + seed (int): random seed + state_shape (tuple): shape of the state, default to self.state_shape if not provided + n_sample (int): number of samples to generate + is_negative_prompt (bool): use negative prompt t5 in uncondition if true + num_steps (int): number of steps for the diffusion process + solver_option (str): differential equation solver option, default to "2ab"~(mulitstep solver) + return_noise (bool): return the initial noise or not, used for ODE pairs generation + """ + self._normalize_video_databatch_inplace(data_batch) + self._augment_image_dim_inplace(data_batch) + is_image_batch = self.is_image_batch(data_batch) + if n_sample is None: + input_key = self.input_image_key if is_image_batch else self.input_data_key + n_sample = data_batch[input_key].shape[0] + if state_shape is None: + if is_image_batch: + state_shape = (self.state_shape[0], 1, *self.state_shape[2:]) # C,T,H,W + + x0_fn = self.get_x0_fn_from_batch(data_batch, guidance, is_negative_prompt=is_negative_prompt) + + x_sigma_max = ( + misc.arch_invariant_rand( + (n_sample,) + tuple(state_shape), + torch.float32, + self.tensor_kwargs["device"], + seed, + ) + * self.sde.sigma_max + ) + + if self.net.is_context_parallel_enabled: + x_sigma_max = split_inputs_cp(x=x_sigma_max, seq_dim=2, cp_group=self.net.cp_group) + + samples = self.sampler( + x0_fn, x_sigma_max, num_steps=num_steps, sigma_max=self.sde.sigma_max, solver_option=solver_option + ) + if self.net.is_context_parallel_enabled: + samples = cat_outputs_cp(samples, seq_dim=2, cp_group=self.net.cp_group) + + if return_noise: + if self.net.is_context_parallel_enabled: + x_sigma_max = cat_outputs_cp(x_sigma_max, seq_dim=2, cp_group=self.net.cp_group) + return samples, x_sigma_max / self.sde.sigma_max + + return samples + + def on_after_backward(self, iteration: int = 0): + finalize_model_grads([self]) + + def get_grad_norm( + self, + norm_type: Union[int, float] = 2, + filter_fn: Callable[[str, torch.nn.Parameter], bool] | None = None, + ) -> float: + """Calculate the norm of gradients, handling model parallel parameters. + + This function is adapted from torch.nn.utils.clip_grad.clip_grad_norm_ + with added functionality to handle model parallel parameters. + + Args: + norm_type (float or int): Type of norm to use. Can be 2 for L2 norm. + 'inf' for infinity norm is not supported. + filter_fn (callable, optional): Function to filter parameters for norm calculation. + Takes parameter name and parameter as input, returns True if this parameter is sharded else False. + + Returns: + float: Total norm of the parameters (viewed as a single vector). + + Note: + - Uses NVIDIA's multi-tensor applier for efficient norm calculation. + - Handles both model parallel and non-model parallel parameters separately. + - Currently only supports L2 norm (norm_type = 2). + """ + # Get model parallel group if parallel state is initialized + if parallel_state.is_initialized(): + model_parallel_group = parallel_state.get_model_parallel_group() + else: + model_parallel_group = None + + # Default filter function to identify tensor parallel parameters + if filter_fn is None: + + def is_tp(name, param): + return ( + any(key in name for key in ["to_q.0", "to_k.0", "to_v.0", "to_out.0", "layer1", "layer2"]) + and "_extra_state" not in name + ) + + filter_fn = is_tp + + # Separate gradients into model parallel and non-model parallel + without_mp_grads_for_norm = [] + with_mp_grads_for_norm = [] + for name, param in self.named_parameters(): + if param.grad is not None: + if filter_fn(name, param): + with_mp_grads_for_norm.append(param.grad.detach()) + else: + without_mp_grads_for_norm.append(param.grad.detach()) + + # Only L2 norm is currently supported + if norm_type != 2.0: + raise NotImplementedError(f"Norm type {norm_type} is not supported. Only L2 norm (2.0) is implemented.") + + # Calculate L2 norm using NVIDIA's multi-tensor applier + dummy_overflow_buf = torch.tensor([0], dtype=torch.int, device="cuda") + + # Calculate norm for non-model parallel gradients + without_mp_grad_norm = torch.tensor([0], dtype=torch.float, device="cuda") + if without_mp_grads_for_norm: + without_mp_grad_norm, _ = multi_tensor_applier( + l2_norm_impl, + dummy_overflow_buf, + [without_mp_grads_for_norm], + False, # no per-parameter norm + ) + + # Calculate norm for model parallel gradients + with_mp_grad_norm = torch.tensor([0], dtype=torch.float, device="cuda") + if with_mp_grads_for_norm: + with_mp_grad_norm, _ = multi_tensor_applier( + l2_norm_impl, + dummy_overflow_buf, + [with_mp_grads_for_norm], + False, # no per-parameter norm + ) + + # Square the norms as we'll be summing across model parallel GPUs + total_without_mp_norm = without_mp_grad_norm**2 + total_with_mp_norm = with_mp_grad_norm**2 + + # Sum across all model-parallel GPUs + torch.distributed.all_reduce(total_with_mp_norm, op=torch.distributed.ReduceOp.SUM, group=model_parallel_group) + + # Combine norms from model parallel and non-model parallel gradients + total_norm = (total_with_mp_norm.item() + total_without_mp_norm.item()) ** 0.5 + + return total_norm + + def clip_grad_norm_(self, max_norm: float): + """ + This function performs gradient clipping to prevent exploding gradients. + It calculates the total norm of the gradients, and if it exceeds the + specified max_norm, scales the gradients down proportionally. + + Args: + max_norm (float): The maximum allowed norm for the gradients. + + Returns: + torch.Tensor: The total norm of the gradients before clipping. + + Note: + This implementation uses NVIDIA's multi-tensor applier for efficiency. + """ + # Collect gradients from all parameters that require gradients + grads = [] + for param in self.parameters(): + if param.grad is not None: + grads.append(param.grad.detach()) + + # Calculate the total norm of the gradients + total_norm = self.get_grad_norm() + + # Compute the clipping coefficient + clip_coeff = max_norm / (total_norm + 1.0e-6) + + # Apply gradient clipping if the total norm exceeds max_norm + if clip_coeff < 1.0: + dummy_overflow_buf = torch.tensor([0], dtype=torch.int, device="cuda") + # Apply the scaling to the gradients using multi_tensor_applier for efficiency + multi_tensor_applier(multi_tensor_scale_impl, dummy_overflow_buf, [grads, grads], clip_coeff) + + return torch.tensor([total_norm]) + + +def _allreduce_layernorm_grads(model: List[torch.nn.Module]): + """ + All-reduce the following layernorm grads: + - When tensor parallel is enabled, all-reduce grads of QK-layernorm + - When sequence parallel, all-reduce grads of AdaLN, t_embedder, additional_timestamp_embedder, + and affline_norm. + """ + sequence_parallel = getattr(parallel_state, "sequence_parallel", False) + + if parallel_state.get_tensor_model_parallel_world_size() > 1: + grads = [] + for model_chunk in model: + for name, param in model_chunk.named_parameters(): + if not param.requires_grad: + continue + + if "to_q.1" in name or "to_k.1" in name: # TP # Q-layernorm # K-layernorm + grad = param.grad + if grad is not None: + grads.append(grad.data) + + if sequence_parallel: # TP + SP + if ( + "t_embedder" in name + or "adaLN_modulation" in name + or "additional_timestamp_embedder" in name + or "affline_norm" in name + or "input_hint_block" in name + or "zero_blocks" in name + ): + grad = param.grad + if grad is not None: + grads.append(grad.data) + + if grads: + coalesced = _flatten_dense_tensors(grads) + torch.distributed.all_reduce(coalesced, group=parallel_state.get_tensor_model_parallel_group()) + for buf, synced in zip(grads, _unflatten_dense_tensors(coalesced, grads)): + buf.copy_(synced) + + +def finalize_model_grads(model: List[torch.nn.Module]): + """ + All-reduce layernorm grads for tensor/sequence parallelism. + Reference implementation: https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/distributed/finalize_model_grads.py#L99 + """ + + _allreduce_layernorm_grads(model) + + +@diffusion_fsdp_class_decorator +class FSDPDiffusionModel(DiffusionModel): + pass diff --git a/cosmos_transfer1/diffusion/training/models/model_ctrl.py b/cosmos_transfer1/diffusion/training/models/model_ctrl.py new file mode 100644 index 00000000..58db6883 --- /dev/null +++ b/cosmos_transfer1/diffusion/training/models/model_ctrl.py @@ -0,0 +1,720 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import torch +from einops import rearrange +from megatron.core import parallel_state +from torch import Tensor +from typing import Callable, Dict, Optional, Tuple, Union, Type, TypeVar + +import torch +from einops import rearrange +from megatron.core import parallel_state +from torch import Tensor + +from cosmos_transfer1.diffusion.conditioner import DataType, VideoConditionerWithCtrl, CosmosCondition +from cosmos_transfer1.diffusion.diffusion.modules.res_sampler import COMMON_SOLVER_OPTIONS +from cosmos_transfer1.diffusion.module.parallel import cat_outputs_cp, split_inputs_cp +from cosmos_transfer1.diffusion.training.models.model import _broadcast, broadcast_condition +from cosmos_transfer1.diffusion.training.models.model_image import diffusion_fsdp_class_decorator +from cosmos_transfer1.diffusion.training.models.extend_model import ExtendDiffusionModel as ExtendVideoDiffusionModel +from cosmos_transfer1.diffusion.inference.inference_utils import non_strict_load_model, merge_patches_into_video, split_video_into_patches +from cosmos_transfer1.utils import log, misc +from cosmos_transfer1.utils.lazy_config import instantiate + + +T = TypeVar("T") +IS_PREPROCESSED_KEY = "is_preprocessed" + +def ctrlnet_decorator(base_class: Type[T]) -> Type[T]: + class CtrlNetModel(base_class): + def __init__(self, config, fsdp_checkpointer=None): + if fsdp_checkpointer is not None: + return super().__init__(config, fsdp_checkpointer) + else: + return super().__init__(config) + + def build_model(self) -> torch.nn.ModuleDict: + log.info("Start creating base model") + base_model = super().build_model() + # initialize base model + config = self.config + self.load_base_model(base_model) + log.info("Done creating base model") + + log.info("Start creating ctrlnet model") + net = instantiate(self.config.net_ctrl) + conditioner = base_model.conditioner + logvar = base_model.logvar + # initialize controlnet encoder + model = torch.nn.ModuleDict({"net": net, "conditioner": conditioner, "logvar": logvar}) + model.load_state_dict(base_model.state_dict(), strict=False) + + model.base_model = base_model + if not config.finetune_base_model: + model.base_model.requires_grad_(False) + log.critical("Only training ctrlnet model and keeping base model frozen") + else: + log.critical("Also training base model") + log.info("Done creating ctrlnet model") + + self.hint_key = self.config.hint_key["hint_key"] + return model + + @property + def base_net(self): + return self.model.base_model.net + + def on_train_start(self, memory_format: torch.memory_format = torch.preserve_format) -> None: + super().on_train_start(memory_format) + # self.base_model = self.base_model.to(memory_format=memory_format, **self.tensor_kwargs) + self.model = self.model.to(memory_format=memory_format, **self.tensor_kwargs) + if parallel_state.is_initialized() and parallel_state.get_tensor_model_parallel_world_size() > 1: + if parallel_state.sequence_parallel: + self.base_net.enable_sequence_parallel() + if ( + hasattr(self.config, "use_torch_compile") and self.config.use_torch_compile + ): # compatible with old config + # not tested yet + if torch.__version__ < "2.3": + log.warning( + "torch.compile in Pytorch version older than 2.3 doesn't work well with activation checkpointing.\n" + "It's very likely there will be no significant speedup from torch.compile.\n" + "Please use at least 24.04 Pytorch container, or imaginaire4:v7 container." + ) + self.base_net = torch.compile(self.base_net, dynamic=False, disable=not self.config.use_torch_compile) + + def load_base_model(self, base_model) -> None: + config = self.config + if config.base_load_from is not None: + checkpoint_path = config.base_load_from["load_path"] + else: + checkpoint_path = "" + + if "*" in checkpoint_path: + # there might be better ways to decide if it's a converted tp checkpoint + mp_rank = parallel_state.get_model_parallel_group().rank() + checkpoint_path = checkpoint_path.replace("*", f"{mp_rank}") + + if checkpoint_path: + log.info(f"Loading base model checkpoint (local): {checkpoint_path}") + state_dict = torch.load(checkpoint_path, map_location=lambda storage, loc: storage) + log.success(f"Complete loading base model checkpoint (local): {checkpoint_path}") + + if "ema" in state_dict: + # Copy the base model weights from ema model. + log.info("Copying ema to base model") + base_state_dict = {k.replace("-", "."): v for k, v in state_dict["ema"].items()} + elif "model" in state_dict: + # Copy the base model weights from reg model. + log.warning("Using non-EMA base model") + base_state_dict = state_dict["model"] + else: + log.info("Loading from an EMA only model") + base_state_dict = state_dict + try: + base_model.load_state_dict(base_state_dict, strict=False) + except Exception: + log.critical("load model in non-strict mode") + log.critical(non_strict_load_model(base_model, base_state_dict), rank0_only=False) + log.info("Done loading the base model checkpoint.") + + return CtrlNetModel + + +def video_ctrlnet_decorator(base_class: Type[T]) -> Type[T]: + class VideoDiffusionModelWithCtrlWrapper(base_class): + def __init__(self, config): + super().__init__(config) + if hasattr(config, "pixel_corruptor") and config.pixel_corruptor is not None: + self.pixel_corruptor = instantiate(config.pixel_corruptor) + self.pixel_corruptor.to(**self.tensor_kwargs) + else: + self.pixel_corruptor = None + + def get_data_and_condition( + self, data_batch: dict[str, Tensor], **kwargs + ) -> Tuple[Tensor, VideoConditionerWithCtrl]: + # process the control input + hint_key = self.config.hint_key["hint_key"] + is_image_batch = self.is_image_batch(data_batch) + _data = {hint_key: data_batch[hint_key]} + if IS_PREPROCESSED_KEY in data_batch: + _data[IS_PREPROCESSED_KEY] = data_batch[IS_PREPROCESSED_KEY] + if not is_image_batch: + self._normalize_video_databatch_inplace(_data, input_key=hint_key) + # if it is an image batch, the control input is also image + if self.input_image_key in data_batch: + self._augment_image_dim_inplace(_data, input_key=hint_key) + data_batch[hint_key] = _data[hint_key] + # else: + # raise NotImplementedError(f"{self.config.hint_key} is not implemented.") + data_batch["hint_key"] = hint_key + raw_state, latent_state, condition = super().get_data_and_condition(data_batch, **kwargs) + # if not torch.is_grad_enabled() and all(self.config.hint_mask): + use_multicontrol = ( + ("control_weight" in data_batch) + and not isinstance(data_batch["control_weight"], float) + and data_batch["control_weight"].shape[0] > 1 + ) + if use_multicontrol: # encode individual conditions separately + latent_hint = [] + num_conditions = data_batch[data_batch["hint_key"]].size(1) // 3 + for i in range(num_conditions): + cond_mask = [False] * num_conditions + cond_mask[i] = True + latent_hint += [self.encode_latent(data_batch, cond_mask=cond_mask)] + latent_hint = torch.cat(latent_hint) + else: + latent_hint = self.encode_latent(data_batch) + # copied from model.py + is_image_batch = self.is_image_batch(data_batch) + is_video_batch = not is_image_batch + # VAE has randomness. CP/TP group should have the same encoded output. + + latent_hint = _broadcast(latent_hint, to_tp=True, to_cp=is_video_batch) + + # add extra conditions + data_batch["latent_hint"] = latent_hint + setattr(condition, hint_key, latent_hint) + setattr(condition, "base_model", self.model.base_model) + return raw_state, latent_state, condition + + def encode_latent(self, data_batch: dict, cond_mask: list = []) -> torch.Tensor: + x = data_batch[data_batch["hint_key"]] + if torch.is_grad_enabled() and self.pixel_corruptor is not None: + x = self.pixel_corruptor(x) + latent = [] + # control input goes through tokenizer, which always takes 3-input channels + num_conditions = x.size(1) // 3 # input conditions were concatenated along channel dimension + if num_conditions > 1 and self.config.hint_dropout_rate > 0: + if torch.is_grad_enabled(): # during training, randomly dropout some conditions + cond_mask = torch.rand(num_conditions) > self.config.hint_dropout_rate + if not cond_mask.any(): # make sure at least one condition is present + cond_mask[torch.randint(num_conditions, (1,)).item()] = True + elif not cond_mask: # during inference, use hint_mask to indicate which conditions are used + cond_mask = self.config.hint_mask + else: + cond_mask = [True] * num_conditions + for idx in range(0, x.size(1), 3): + x_rgb = x[:, idx : idx + 3] + if self.config.hint_key["grayscale"]: + x_rgb = x_rgb.mean(dim=1, keepdim=True).expand_as(x_rgb) + # if idx == 0: + # x_max = x_rgb + # else: + # x_max = torch.maximum(x_rgb, x_max) + if not cond_mask[idx // 3]: # if the condition is not selected, replace with a black image + x_rgb = torch.zeros_like(x_rgb) + latent.append(self.encode(x_rgb)) + # latent.append(self.encode(x_max)) + latent = torch.cat(latent, dim=1) + return latent + + def compute_loss_with_epsilon_and_sigma( + self, + data_batch: dict[str, Tensor], + x0_from_data_batch: Tensor, + x0: Tensor, + condition: CosmosCondition, + epsilon: Tensor, + sigma: Tensor, + ): + if self.is_image_batch(data_batch): + # Turn off CP + self.net.disable_context_parallel() + self.base_net.disable_context_parallel() + else: + if parallel_state.get_context_parallel_world_size() > 1: + # Turn on CP + cp_group = parallel_state.get_context_parallel_group() + self.net.enable_context_parallel(cp_group) + self.base_net.enable_context_parallel(cp_group) + log.debug("[CP] Split hint_input") + hint_key = self.config.hint_key["hint_key"] + x_hint_raw = getattr(condition, hint_key) + x_hint = split_inputs_cp(x=x_hint_raw, seq_dim=2, cp_group=self.net.cp_group) + setattr(condition, hint_key, x_hint) + return super().compute_loss_with_epsilon_and_sigma( + data_batch, x0_from_data_batch, x0, condition, epsilon, sigma + ) + + def get_x0_fn_from_batch( + self, + data_batch: Dict, + guidance: float = 1.5, + is_negative_prompt: bool = False, + condition_latent: torch.Tensor = None, + num_condition_t: Union[int, None] = None, + condition_video_augment_sigma_in_inference: float = None, + seed_inference: int = 1, + ) -> Callable: + """ + Generates a callable function `x0_fn` based on the provided data batch and guidance factor. + + This function first processes the input data batch through a conditioning workflow (`conditioner`) to obtain conditioned and unconditioned states. It then defines a nested function `x0_fn` which applies a denoising operation on an input `noise_x` at a given noise level `sigma` using both the conditioned and unconditioned states. + + Args: + - data_batch (Dict): A batch of data used for conditioning. The format and content of this dictionary should align with the expectations of the `self.conditioner` + - guidance (float, optional): A scalar value that modulates the influence of the conditioned state relative to the unconditioned state in the output. Defaults to 1.5. + - is_negative_prompt (bool): use negative prompt t5 in uncondition if true + condition_latent (torch.Tensor): latent tensor in shape B,C,T,H,W as condition to generate video. + - num_condition_t (int): number of condition latent T, used in inference to decide the condition region and config.conditioner.video_cond_bool.condition_location == "first_n" + - condition_video_augment_sigma_in_inference (float): sigma for condition video augmentation in inference + + Returns: + - Callable: A function `x0_fn(noise_x, sigma)` that takes two arguments, `noise_x` and `sigma`, and return x0 predictoin + + The returned function is suitable for use in scenarios where a denoised state is required based on both conditioned and unconditioned inputs, with an adjustable level of guidance influence. + """ + # data_batch should be the one processed by self.get_data_and_condition + if is_negative_prompt: + condition, uncondition = self.conditioner.get_condition_with_negative_prompt(data_batch) + else: + condition, uncondition = self.conditioner.get_condition_uncondition(data_batch) + if hasattr(self, "is_extend_model") and self.is_extend_model: + # Add conditions for long video generation. + if self.is_image_batch(data_batch): + condition.data_type = DataType.IMAGE + uncondition.data_type = DataType.IMAGE + else: + if condition_latent is None: + condition_latent = torch.zeros(data_batch["latent_hint"].shape, **self.tensor_kwargs) + num_condition_t = 0 + condition_video_augment_sigma_in_inference = 1000 + + condition.video_cond_bool = True + condition = self.add_condition_video_indicator_and_video_input_mask( + condition_latent, condition, num_condition_t + ) + if self.config.conditioner.video_cond_bool.add_pose_condition: + condition = self.add_condition_pose(data_batch, condition) + + uncondition.video_cond_bool = True # Not do cfg on condition frames + uncondition = self.add_condition_video_indicator_and_video_input_mask( + condition_latent, uncondition, num_condition_t + ) + if self.config.conditioner.video_cond_bool.add_pose_condition: + uncondition = self.add_condition_pose(data_batch, uncondition) + + # Add extra conditions for ctrlnet. + latent_hint = data_batch["latent_hint"] + hint_key = data_batch["hint_key"] + setattr(condition, hint_key, latent_hint) + if "use_none_hint" in data_batch and data_batch["use_none_hint"]: + setattr(uncondition, hint_key, None) + else: + setattr(uncondition, hint_key, latent_hint) + + to_cp = self.net.is_context_parallel_enabled + # For inference, check if parallel_state is initialized + if parallel_state.is_initialized() and not self.is_image_batch(data_batch): + condition = broadcast_condition(condition, to_tp=True, to_cp=to_cp) + uncondition = broadcast_condition(uncondition, to_tp=True, to_cp=to_cp) + + cp_group = parallel_state.get_context_parallel_group() + latent_hint = getattr(condition, hint_key) + latent_hint = split_inputs_cp(latent_hint, seq_dim=2, cp_group=cp_group) + setattr(condition, hint_key, latent_hint) + if getattr(uncondition, hint_key) is not None: + setattr(uncondition, hint_key, latent_hint) + # else: + # assert not to_cp, "parallel_state is not initialized, context parallel should be turned off." + + setattr(condition, "base_model", self.model.base_model) + setattr(uncondition, "base_model", self.model.base_model) + + def x0_fn(noise_x: torch.Tensor, sigma: torch.Tensor) -> torch.Tensor: + if self.is_image_batch(data_batch) or not issubclass(base_class, ExtendVideoDiffusionModel): + cond_x0 = self.denoise(noise_x, sigma, condition).x0 + uncond_x0 = self.denoise(noise_x, sigma, uncondition).x0 + else: + cond_x0 = self.denoise( + noise_x, + sigma, + condition, + condition_video_augment_sigma_in_inference=condition_video_augment_sigma_in_inference, + seed_inference=seed_inference, + ).x0_pred_replaced + uncond_x0 = self.denoise( + noise_x, + sigma, + uncondition, + condition_video_augment_sigma_in_inference=condition_video_augment_sigma_in_inference, + seed_inference=seed_inference, + ).x0_pred_replaced + return cond_x0 + guidance * (cond_x0 - uncond_x0) + + return x0_fn + + def generate_samples_from_batch( + self, + data_batch: Dict, + guidance: float = 1.5, + seed: int = 1, + state_shape: Tuple | None = None, + n_sample: int | None = None, + is_negative_prompt: bool = False, + num_steps: int = 35, + condition_latent: Union[torch.Tensor, None] = None, + num_condition_t: Union[int, None] = None, + condition_video_augment_sigma_in_inference: float = None, + solver_option: COMMON_SOLVER_OPTIONS = "2ab", + x_sigma_max: Optional[torch.Tensor] = None, + sigma_max: float | None = None, + return_noise: bool = False, + ) -> Tensor: + """ + Generate samples from the batch. Based on given batch, it will automatically determine whether to generate image or video samples. + Different from the base model, this function support condition latent as input, it will create a differnt x0_fn if condition latent is given. + If this feature is stablized, we could consider to move this function to the base model. + + Args: + condition_latent (Optional[torch.Tensor]): latent tensor in shape B,C,T,H,W as condition to generate video. + num_condition_t (Optional[int]): number of condition latent T, if None, will use the whole first half + + return_noise (bool): return the initial noise or not, used for ODE pairs generation. Not used here. Kept for conmpatibility. + """ + self._normalize_video_databatch_inplace(data_batch) + self._augment_image_dim_inplace(data_batch) + is_image_batch = self.is_image_batch(data_batch) + if is_image_batch: + log.debug("image batch, call base model generate_samples_from_batch") + return super().generate_samples_from_batch( + data_batch, + guidance=guidance, + seed=seed, + state_shape=state_shape, + n_sample=n_sample, + is_negative_prompt=is_negative_prompt, + num_steps=num_steps, + ) + if n_sample is None: + input_key = self.input_image_key if is_image_batch else self.input_data_key + n_sample = data_batch[input_key].shape[0] + if state_shape is None: + if is_image_batch: + state_shape = (self.state_shape[0], 1, *self.state_shape[2:]) # C,T,H,W + else: + log.debug(f"Default Video state shape is used. {self.state_shape}") + state_shape = self.state_shape + + # assert condition_latent is not None, "condition_latent should be provided" + + # if self.net.is_context_parallel_enabled: + # data_batch["latent_hint"] = split_inputs_cp(x=data_batch["latent_hint"], seq_dim=2, cp_group=self.net.cp_group) + + x0_fn = self.get_x0_fn_from_batch( + data_batch, + guidance, + is_negative_prompt=is_negative_prompt, + condition_latent=condition_latent, + num_condition_t=num_condition_t, + condition_video_augment_sigma_in_inference=condition_video_augment_sigma_in_inference, + seed_inference=seed, + ) + + if sigma_max is None: + sigma_max = self.sde.sigma_max + + if x_sigma_max is None: + x_sigma_max = ( + misc.arch_invariant_rand( + (n_sample,) + tuple(state_shape), + torch.float32, + self.tensor_kwargs["device"], + seed, + ) + * sigma_max + ) + + if self.net.is_context_parallel_enabled: + x_sigma_max = _broadcast(x_sigma_max, to_tp=True, to_cp=True) + x_sigma_max = split_inputs_cp(x=x_sigma_max, seq_dim=2, cp_group=self.net.cp_group) + + samples = self.sampler( + x0_fn, x_sigma_max, num_steps=num_steps, sigma_max=sigma_max, solver_option=solver_option + ) + if self.net.is_context_parallel_enabled: + samples = cat_outputs_cp(samples, seq_dim=2, cp_group=self.net.cp_group) + + return samples + + def get_patch_based_x0_fn( + self, + data_batch: Dict, + guidance: float = 1.5, + is_negative_prompt: bool = False, + condition_latent: torch.Tensor = None, + num_condition_t: Union[int, None] = None, + condition_video_augment_sigma_in_inference: float = None, + target_h: int = 2112, + target_w: int = 3840, + patch_h: int = 704, + patch_w: int = 1280, + seed_inference: int = 1, + ) -> Callable: + """ + Generates a callable function `x0_fn` based on the provided data batch and guidance factor. + The function will split the input into patches, run inference on each patch, then stitch them together. + + Additional args to original function: + target_h (int): final stitched video height + target_w (int): final stitched video width + patch_h (int): video patch height for each network inference + patch_w (int): video patch width for each network inference + + Returns: + - Callable: A function `x0_fn(noise_x, sigma)` that takes two arguments, `noise_x` and `sigma`, and return x0 prediction + """ + assert patch_h <= target_h and patch_w <= target_w + # data_batch should be the one processed by self.get_data_and_condition + if is_negative_prompt: + condition, uncondition = self.conditioner.get_condition_with_negative_prompt(data_batch) + else: + condition, uncondition = self.conditioner.get_condition_uncondition(data_batch) + if hasattr(self, "is_extend_model") and self.is_extend_model: + # Add conditions for long video generation. + if condition_latent is None: + condition_latent = torch.zeros(data_batch["latent_hint"].shape, **self.tensor_kwargs) + num_condition_t = 0 + condition_video_augment_sigma_in_inference = 1000 + + condition.video_cond_bool = True + condition = self.add_condition_video_indicator_and_video_input_mask( + condition_latent[:1], condition, num_condition_t + ) + uncondition.video_cond_bool = True # Not do cfg on condition frames + uncondition = self.add_condition_video_indicator_and_video_input_mask( + condition_latent[:1], uncondition, num_condition_t + ) + # Add extra conditions for ctrlnet. + latent_hint = data_batch["latent_hint"] + hint_key = data_batch["hint_key"] + setattr(condition, hint_key, latent_hint) + if "use_none_hint" in data_batch and data_batch["use_none_hint"]: + setattr(uncondition, hint_key, None) + else: + setattr(uncondition, hint_key, latent_hint) + + to_cp = self.net.is_context_parallel_enabled + # For inference, check if parallel_state is initialized + if parallel_state.is_initialized() and not self.is_image_batch(data_batch): + condition = broadcast_condition(condition, to_tp=True, to_cp=to_cp) + uncondition = broadcast_condition(uncondition, to_tp=True, to_cp=to_cp) + cp_group = parallel_state.get_context_parallel_group() + latent_hint = getattr(condition, hint_key) + latent_hint = split_inputs_cp(latent_hint, seq_dim=2, cp_group=cp_group) + + setattr(condition, "base_model", self.model.base_model) + setattr(uncondition, "base_model", self.model.base_model) + if hasattr(self, "hint_encoders"): + self.model.net.hint_encoders = self.hint_encoders + + def x0_fn(noise_x: torch.Tensor, sigma: torch.Tensor): + w, h = target_w, target_h + n_img_w = (w - 1) // patch_w + 1 + n_img_h = (h - 1) // patch_h + 1 + + overlap_size_w = overlap_size_h = 0 + if n_img_w > 1: + overlap_size_w = (n_img_w * patch_w - w) // (n_img_w - 1) + assert n_img_w * patch_w - overlap_size_w * (n_img_w - 1) == w + if n_img_h > 1: + overlap_size_h = (n_img_h * patch_h - h) // (n_img_h - 1) + assert n_img_h * patch_h - overlap_size_h * (n_img_h - 1) == h + + batch_images = noise_x + batch_sigma = sigma + output = [] + for idx, cur_images in enumerate(batch_images): + noise_x = cur_images.unsqueeze(0) + sigma = batch_sigma[idx : idx + 1] + condition.gt_latent = condition_latent[idx : idx + 1] + uncondition.gt_latent = condition_latent[idx : idx + 1] + setattr(condition, hint_key, latent_hint[idx : idx + 1]) + if getattr(uncondition, hint_key) is not None: + setattr(uncondition, hint_key, latent_hint[idx : idx + 1]) + + if self.is_image_batch(data_batch) or not issubclass(base_class, ExtendVideoDiffusionModel): + cond_x0 = self.denoise(noise_x, sigma, condition).x0 + uncond_x0 = self.denoise(noise_x, sigma, uncondition).x0 + else: + cond_x0 = self.denoise( + noise_x, + sigma, + condition, + condition_video_augment_sigma_in_inference=condition_video_augment_sigma_in_inference, + seed_inference=seed_inference, + ).x0_pred_replaced + uncond_x0 = self.denoise( + noise_x, + sigma, + uncondition, + condition_video_augment_sigma_in_inference=condition_video_augment_sigma_in_inference, + seed_inference=seed_inference, + ).x0_pred_replaced + x0 = cond_x0 + guidance * (cond_x0 - uncond_x0) + output.append(x0) + output = rearrange(torch.stack(output), "(n t) b ... -> (b n t) ...", n=n_img_h, t=n_img_w) # 8x3xhxw + final_output = merge_patches_into_video(output, overlap_size_h, overlap_size_w, n_img_h, n_img_w) + final_output = split_video_into_patches(final_output, patch_h, patch_w) + return final_output + + return x0_fn + + def generate_samples_from_patches( + self, + data_batch: Dict, + guidance: float = 1.5, + seed: int = 1, + state_shape: Tuple | None = None, + n_sample: int | None = None, + is_negative_prompt: bool = False, + num_steps: int = 35, + condition_latent: Union[torch.Tensor, None] = None, + num_condition_t: Union[int, None] = None, + condition_video_augment_sigma_in_inference: float = None, + solver_option: COMMON_SOLVER_OPTIONS = "2ab", + x_sigma_max: Optional[torch.Tensor] = None, + sigma_max: float | None = None, + target_h: int = 2112, + target_w: int = 3840, + patch_h: int = 704, + patch_w: int = 1280, + ) -> Tensor: + """ + Generate samples from the batch using patch-based inference. During each denoising step, it will denoise each patch + separately then average the overlapping regions. + + Additional args to original function: + target_h (int): final stitched video height + target_w (int): final stitched video width + patch_h (int): video patch height for each network inference + patch_w (int): video patch width for each network inference + """ + assert patch_h <= target_h and patch_w <= target_w + self._normalize_video_databatch_inplace(data_batch) + self._augment_image_dim_inplace(data_batch) + is_image_batch = self.is_image_batch(data_batch) + if is_image_batch: + log.debug("image batch, call base model generate_samples_from_batch") + return super().generate_samples_from_batch( + data_batch, + guidance=guidance, + seed=seed, + state_shape=state_shape, + n_sample=n_sample, + is_negative_prompt=is_negative_prompt, + num_steps=num_steps, + ) + if n_sample is None: + input_key = self.input_image_key if is_image_batch else self.input_data_key + n_sample = data_batch[input_key].shape[0] + if state_shape is None: + if is_image_batch: + state_shape = (self.state_shape[0], 1, *self.state_shape[2:]) # C,T,H,W + else: + log.debug(f"Default Video state shape is used. {self.state_shape}") + state_shape = self.state_shape + + x0_fn = self.get_patch_based_x0_fn( + data_batch, + guidance, + is_negative_prompt=is_negative_prompt, + condition_latent=condition_latent, + num_condition_t=num_condition_t, + condition_video_augment_sigma_in_inference=condition_video_augment_sigma_in_inference, + target_h=target_h, + target_w=target_w, + patch_h=patch_h, + patch_w=patch_w, + seed_inference=seed, + ) + + if sigma_max is None: + sigma_max = self.sde.sigma_max + + if x_sigma_max is None: + x_sigma_max = ( + misc.arch_invariant_rand( + (n_sample,) + tuple(state_shape), + torch.float32, + self.tensor_kwargs["device"], + seed, + ) + * sigma_max + ) + + if self.net.is_context_parallel_enabled: + x_sigma_max = _broadcast(x_sigma_max, to_tp=True, to_cp=True) + x_sigma_max = split_inputs_cp(x=x_sigma_max, seq_dim=2, cp_group=self.net.cp_group) + + samples = self.sampler( + x0_fn, x_sigma_max, num_steps=num_steps, sigma_max=sigma_max, solver_option=solver_option + ) + if self.net.is_context_parallel_enabled: + samples = cat_outputs_cp(samples, seq_dim=2, cp_group=self.net.cp_group) + + return samples + + @torch.no_grad() + def validation_step( + self, data: dict[str, torch.Tensor], iteration: int + ) -> tuple[dict[str, torch.Tensor], torch.Tensor]: + """ + save generated videos + """ + raw_data, x0, condition = self.get_data_and_condition(data) + guidance = data["guidance"] + sigma_max = data["sigma_max"] + is_negative_prompt = data["is_negative_prompt"] + data = misc.to(data, **self.tensor_kwargs) + x_sigma_max = None + if sigma_max is not None: + x_sigma_max = self.get_x_from_clean(x0, sigma_max) + sample = self.generate_samples_from_batch( + data, + guidance=guidance, + # make sure no mismatch and also works for cp + state_shape=x0.shape[1:], + n_sample=x0.shape[0], + x_sigma_max=x_sigma_max, + sigma_max=sigma_max, + is_negative_prompt=is_negative_prompt, + ) + sample = self.decode(sample) + gt = raw_data + hint = data[data["hint_key"]][:, :3] + result = torch.cat([hint, sample], dim=3) + gt = torch.cat([hint, gt], dim=3) + caption = data["ai_caption"] + return {"gt": gt, "result": result, "caption": caption}, torch.tensor([0]).to(**self.tensor_kwargs) + + return VideoDiffusionModelWithCtrlWrapper + + +@video_ctrlnet_decorator +@ctrlnet_decorator +class VideoDiffusionModelWithCtrl(ExtendVideoDiffusionModel): + pass + + +@diffusion_fsdp_class_decorator +@video_ctrlnet_decorator +@ctrlnet_decorator +class VideoDiffusionFSDPModelWithCtrl(ExtendVideoDiffusionModel): + pass \ No newline at end of file diff --git a/cosmos_transfer1/diffusion/training/models/model_image.py b/cosmos_transfer1/diffusion/training/models/model_image.py new file mode 100644 index 00000000..ce4edc00 --- /dev/null +++ b/cosmos_transfer1/diffusion/training/models/model_image.py @@ -0,0 +1,930 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import functools +from contextlib import contextmanager +from dataclasses import dataclass, fields +from typing import Any, Callable, Dict, Mapping, Optional, Tuple, Type, TypeVar + +import numpy as np +import torch +import torch.nn.functional as F +from megatron.core import parallel_state +from torch.distributed.fsdp import FullStateDictConfig +from torch.distributed.fsdp import FullyShardedDataParallel as FSDP +from torch.distributed.fsdp import ShardingStrategy, StateDictType +from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy +from torch.nn.modules.module import _IncompatibleKeys + +from cosmos_transfer1.diffusion.functional.batch_ops import batch_mul +from cosmos_transfer1.diffusion.module.blocks import FourierFeatures +from cosmos_transfer1.diffusion.module.pretrained_vae import BaseVAE +from cosmos_transfer1.diffusion.diffusion.modules.denoiser_scaling import EDMScaling +from cosmos_transfer1.diffusion.diffusion.modules.res_sampler import COMMON_SOLVER_OPTIONS, Sampler +from cosmos_transfer1.diffusion.training.functional.loss import create_per_sample_loss_mask +from cosmos_transfer1.diffusion.training.utils.fsdp_helper import apply_fsdp_checkpointing, hsdp_device_mesh +from cosmos_transfer1.diffusion.training.utils.optim_instantiate import get_base_scheduler +from cosmos_transfer1.diffusion.diffusion.types import DenoisePrediction +from cosmos_transfer1.utils import distributed, log, misc +from cosmos_transfer1.utils.ema import FastEmaModelUpdater +from cosmos_transfer1.utils.lazy_config import LazyDict +from cosmos_transfer1.utils.lazy_config import instantiate as lazy_instantiate +from cosmos_transfer1.utils.model import Model + + +@dataclass +class CosmosCondition: + crossattn_emb: torch.Tensor + crossattn_mask: torch.Tensor + padding_mask: Optional[torch.Tensor] = None + scalar_feature: Optional[torch.Tensor] = None + + def to_dict(self) -> Dict[str, Optional[torch.Tensor]]: + return {f.name: getattr(self, f.name) for f in fields(self)} + + +class DiffusionModel(Model): + def __init__(self, config): + super().__init__() + + self.config = config + + # how many sample have been processed + self.sample_counter = 0 + self.precision = { + "float32": torch.float32, + "float16": torch.float16, + "bfloat16": torch.bfloat16, + }[config.precision] + self.tensor_kwargs = {"device": "cuda", "dtype": self.precision} + log.warning(f"DiffusionModel: precision {self.precision}") + # Timer passed to network to detect slow ranks. + # 1. set data keys and data information + self.sigma_data = config.sigma_data + self.state_shape = list(config.latent_shape) + self.setup_data_key() + + # 2. setup up diffusion processing and scaling~(pre-condition), sampler + self.sde = lazy_instantiate(config.sde) + self.sampler = Sampler() + self.scaling = EDMScaling(self.sigma_data) + + # 3. vae + with misc.timer("DiffusionModel: set_up_vae"): + self.vae: BaseVAE = lazy_instantiate(config.vae) + assert ( + self.vae.latent_ch == self.state_shape[0] + ), f"latent_ch {self.vae.latent_ch} != state_shape {self.state_shape[0]}" + + # 4. Set up loss options, including loss masking, loss reduce and loss scaling + self.loss_masking: Optional[Dict] = config.loss_masking + self.loss_reduce = getattr(config, "loss_reduce", "mean") + assert self.loss_reduce in ["mean", "sum"] + self.loss_scale = getattr(config, "loss_scale", 1.0) + log.critical(f"Using {self.loss_reduce} loss reduce with loss scale {self.loss_scale}") + log.critical(f"Enable loss masking: {config.loss_mask_enabled}") + + # 5. diffusion neural networks part + self.set_up_model() + + def setup_data_key(self) -> None: + self.input_data_key = self.config.input_data_key + + def build_model(self) -> torch.nn.ModuleDict: + config = self.config + net = lazy_instantiate(config.net) + conditioner = lazy_instantiate(config.conditioner) + logvar = torch.nn.Sequential( + FourierFeatures(num_channels=128, normalize=True), torch.nn.Linear(128, 1, bias=False) + ) + + return torch.nn.ModuleDict( + { + "net": net, + "conditioner": conditioner, + "logvar": logvar, + } + ) + + @misc.timer("DiffusionModel: set_up_model") + def set_up_model(self): + config = self.config + self.model = self.build_model() + if config.ema.enabled: + with misc.timer("DiffusionModel: instantiate ema"): + config.ema.model = self.model + self.model_ema = lazy_instantiate(config.ema) + config.ema.model = None + else: + self.model_ema = None + + @property + def net(self): + return self.model.net + + @property + def conditioner(self): + return self.model.conditioner + + def on_before_zero_grad( + self, optimizer: torch.optim.Optimizer, scheduler: torch.optim.lr_scheduler.LRScheduler, iteration: int + ) -> None: + """ + update the model_ema + """ + if self.config.ema.enabled: + self.model_ema.update_average(self.model, iteration) + + def on_train_start(self, memory_format: torch.memory_format = torch.preserve_format) -> None: + if self.config.ema.enabled: + self.model_ema.to(dtype=torch.float32) + if hasattr(self.vae, "reset_dtype"): + self.vae.reset_dtype() + self.model = self.model.to(memory_format=memory_format, **self.tensor_kwargs) + + if hasattr(self.config, "use_torch_compile") and self.config.use_torch_compile: # compatible with old config + if torch.__version__ < "2.3": + log.warning( + "torch.compile in Pytorch version older than 2.3 doesn't work well with activation checkpointing.\n" + "It's very likely there will be no significant speedup from torch.compile.\n" + "Please use at least 24.04 Pytorch container." + ) + # Increasing cache size. It's required because of the model size and dynamic input shapes resulting in + # multiple different triton kernels. For 28 TransformerBlocks, the cache limit of 256 should be enough for + # up to 9 different input shapes, as 28*9 < 256. If you have more Blocks or input shapes, and you observe + # graph breaks at each Block (detectable with torch._dynamo.explain) or warnings about + # exceeding cache limit, you may want to increase this size. + # Starting with 24.05 Pytorch container, the default value is 256 anyway. + # You can read more about it in the comments in Pytorch source code under path torch/_dynamo/cache_size.py. + torch._dynamo.config.accumulated_cache_size_limit = 256 + # dynamic=False means that a separate kernel is created for each shape. It incurs higher compilation costs + # at initial iterations, but can result in more specialized and efficient kernels. + # dynamic=True currently throws errors in pytorch 2.3. + self.model.net = torch.compile(self.model.net, dynamic=False, disable=not self.config.use_torch_compile) + + def compute_loss_with_epsilon_and_sigma( + self, + data_batch: dict[str, torch.Tensor], + x0_from_data_batch: torch.Tensor, + x0: torch.Tensor, + condition: CosmosCondition, + epsilon: torch.Tensor, + sigma: torch.Tensor, + ): + """ + Compute loss givee epsilon and sigma + + This method is responsible for computing loss give epsilon and sigma. It involves: + 1. Adding noise to the input data using the SDE process. + 2. Passing the noisy data through the network to generate predictions. + 3. Computing the loss based on the difference between the predictions and the original data, \ + considering any configured loss weighting. + + Args: + data_batch (dict): raw data batch draw from the training data loader. + x0_from_data_batch: raw image/video + x0: image/video latent + condition: text condition + epsilon: noise + sigma: noise level + + Returns: + tuple: A tuple containing four elements: + - dict: additional data that used to debug / logging / callbacks + - Tensor 1: kendall loss, + - Tensor 2: MSE loss, + - Tensor 3: EDM loss + + Raises: + AssertionError: If the class is conditional, \ + but no number of classes is specified in the network configuration. + + Notes: + - The method handles different types of conditioning + - The method also supports Kendall's loss + """ + # Get the mean and stand deviation of the marginal probability distribution. + mean, std = self.sde.marginal_prob(x0, sigma) + # Generate noisy observations + xt = mean + batch_mul(std, epsilon) # corrupted data + # make prediction + model_pred = self.denoise(xt, sigma, condition) + # loss weights for different noise levels + weights_per_sigma = self.get_per_sigma_loss_weights(sigma=sigma) + # extra weight for each sample, for example, aesthetic weight, camera weight + weights_per_sample = self.get_per_sample_weight(data_batch, x0_from_data_batch.shape[0]) + # extra loss mask for each sample, for example, human faces, hands + loss_mask_per_sample = self.get_per_sample_loss_mask(data_batch, x0_from_data_batch.shape, x0.shape) + pred_mse = (x0 - model_pred.x0) ** 2 * loss_mask_per_sample + edm_loss = batch_mul(pred_mse, weights_per_sigma * weights_per_sample) + if self.config.loss_add_logvar: + kendall_loss = batch_mul(edm_loss, torch.exp(-model_pred.logvar).view(-1)).flatten( + start_dim=1 + ) + model_pred.logvar.view(-1, 1) + else: + kendall_loss = edm_loss.flatten(start_dim=1) + output_batch = { + "x0": x0, + "xt": xt, + "sigma": sigma, + "weights_per_sigma": weights_per_sigma, + "weights_per_sample": weights_per_sample, + "loss_mask_per_sample": loss_mask_per_sample, + "condition": condition, + "model_pred": model_pred, + "mse_loss": pred_mse.mean(), + "edm_loss": edm_loss.mean(), + } + return output_batch, kendall_loss, pred_mse, edm_loss + + def training_step( + self, data_batch: dict[str, torch.Tensor], iteration: int + ) -> tuple[dict[str, torch.Tensor], torch.Tensor]: + """ + Performs a single training step for the diffusion model. + + This method is responsible for executing one iteration of the model's training. It involves: + 1. Adding noise to the input data using the SDE process. + 2. Passing the noisy data through the network to generate predictions. + 3. Computing the loss based on the difference between the predictions and the original data, \ + considering any configured loss weighting. + + Args: + data_batch (dict): raw data batch draw from the training data loader. + iteration (int): Current iteration number. + + Returns: + tuple: A tuple containing two elements: + - dict: additional data that used to debug / logging / callbacks + - Tensor: The computed loss for the training step as a PyTorch Tensor. + + Raises: + AssertionError: If the class is conditional, \ + but no number of classes is specified in the network configuration. + + Notes: + - The method handles different types of conditioning + - The method also supports Kendall's loss + """ + # Get the input data to noise and denoise~(image, video) and the corresponding conditioner. + x0_from_data_batch, x0, condition = self.get_data_and_condition(data_batch) + + # Sample pertubation noise levels and N(0, 1) noises + sigma, epsilon = self.draw_training_sigma_and_epsilon(x0.size(), condition) + + output_batch, kendall_loss, pred_mse, edm_loss = self.compute_loss_with_epsilon_and_sigma( + data_batch, x0_from_data_batch, x0, condition, epsilon, sigma + ) + + if self.loss_reduce == "mean": + kendall_loss = kendall_loss.mean() * self.loss_scale + elif self.loss_reduce == "sum": + kendall_loss = kendall_loss.sum(dim=1).mean() * self.loss_scale + else: + raise ValueError(f"Invalid loss_reduce: {self.loss_reduce}") + + return output_batch, kendall_loss + + def denoise(self, xt: torch.Tensor, sigma: torch.Tensor, condition: CosmosCondition) -> DenoisePrediction: + """ + Performs denoising on the input noise data, noise level, and condition + + Args: + xt (torch.Tensor): The input noise data. + sigma (torch.Tensor): The noise level. + condition (CosmosCondition): conditional information, generated from self.conditioner + + Returns: + DenoisePrediction: The denoised prediction, it includes clean data predicton (x0), \ + noise prediction (eps_pred) and optional confidence (logvar). + """ + + if getattr(self.config, "use_dummy_temporal_dim", False): + # When using video DiT model for image, we need to use a dummy temporal dimension. + xt = xt.unsqueeze(2) + + xt = xt.to(**self.tensor_kwargs) + sigma = sigma.to(**self.tensor_kwargs) + # get precondition for the network + c_skip, c_out, c_in, c_noise = self.scaling(sigma=sigma) + + # forward pass through the network + net_output = self.net( + x=batch_mul(c_in, xt), # Eq. 7 of https://arxiv.org/pdf/2206.00364.pdf + timesteps=c_noise, # Eq. 7 of https://arxiv.org/pdf/2206.00364.pdf + **condition.to_dict(), + ) + + logvar = self.model.logvar(c_noise) + x0_pred = batch_mul(c_skip, xt) + batch_mul(c_out, net_output) + + # get noise prediction based on sde + eps_pred = batch_mul(xt - x0_pred, 1.0 / sigma) + + if getattr(self.config, "use_dummy_temporal_dim", False): + x0_pred = x0_pred.squeeze(2) + eps_pred = eps_pred.squeeze(2) + + return DenoisePrediction(x0_pred, eps_pred, logvar) + + @torch.no_grad() + def encode(self, state: torch.Tensor) -> torch.Tensor: + return self.vae.encode(state) * self.sigma_data + + @torch.no_grad() + def decode(self, latent: torch.Tensor) -> torch.Tensor: + return self.vae.decode(latent / self.sigma_data) + + def draw_training_sigma_and_epsilon(self, x0_size: int, condition: Any) -> torch.Tensor: + del condition + batch_size = x0_size[0] + epsilon = torch.randn(x0_size, **self.tensor_kwargs) + return self.sde.sample_t(batch_size).to(**self.tensor_kwargs), epsilon + + def get_data_and_condition(self, data_batch: dict[str, torch.Tensor]) -> Tuple[torch.Tensor, CosmosCondition]: + """ + processing data batch draw from data loader and return data and condition that used for denoising task + + Returns: + raw_state (tensor): the image / video data that feed to vae + latent_state (tensor): nosie-free state, the vae latent state + condition (CosmosCondition): condition information for conditional generation. Generated from conditioner + """ + raw_state = data_batch[self.input_data_key] + latent_state = self.encode(raw_state) + condition = self.conditioner(data_batch) + return raw_state, latent_state, condition + + def get_per_sample_weight(self, data_batch: dict[str, torch.Tensor], batch_size: int): + r""" + extra weight for each sample, for example, aesthetic weight + Args: + data_batch: raw data batch draw from the training data loader. + batch_size: int, the batch size of the input data + """ + aesthetic_cfg = getattr(self.config, "aesthetic_finetuning", None) + if (aesthetic_cfg is not None) and getattr(aesthetic_cfg, "enabled", False): + sample_weight = data_batch["aesthetic_weight"] + else: + sample_weight = torch.ones(batch_size, **self.tensor_kwargs) + + camera_cfg = getattr(self.config, "camera_sample_weight", None) + if (camera_cfg is not None) and getattr(camera_cfg, "enabled", False): + sample_weight *= 1 + (data_batch["camera_attributes"][:, 1:].sum(dim=1) != 0) * (camera_cfg.weight - 1) + return sample_weight + + def get_per_sample_loss_mask(self, data_batch, raw_x_shape, latent_x_shape): + """ + extra loss mask for each sample, for example, human faces, hands. + + Args: + data_batch (dict): raw data batch draw from the training data loader. + raw_x_shape (tuple): shape of the input data. We need the raw_x_shape for necessary resize operation. + latent_x_shape (tuple): shape of the latent data + """ + if self.config.loss_mask_enabled: + raw_x_shape = [raw_x_shape[0], 1, *raw_x_shape[2:]] + weights = create_per_sample_loss_mask( + self.loss_masking, data_batch, raw_x_shape, torch.get_default_dtype(), "cuda" + ) + return F.interpolate(weights, size=latent_x_shape[2:], mode="bilinear") + + return 1.0 + + def get_per_sigma_loss_weights(self, sigma: torch.Tensor): + """ + Args: + sigma (tensor): noise level + + Returns: + loss weights per sigma noise level + """ + return (sigma**2 + self.sigma_data**2) / (sigma * self.sigma_data) ** 2 + + def generate_samples(self, batch_size: int, condition: CosmosCondition) -> torch.Tensor: + """ + Generate samples with given condition. It is WITHOUT classifier-free-guidance. + + Args: + batch_size (int): + condition (CosmosCondition): condition information generated from self.conditioner + """ + x_sigma_max = torch.randn(batch_size, *self.state_shape, **self.tensor_kwargs) * self.sde.sigma_max + + def x0_fn(x, t): + return self.denoise(x, t, condition).x0 # ODE function + + return self.sampler(x0_fn, x_sigma_max, sigma_max=self.sde.sigma_max) + + def generate_cfg_samples( + self, batch_size: int, condition: CosmosCondition, uncondition: CosmosCondition, guidance=1.5 + ) -> torch.Tensor: + """ + Generate samples with with classifier-free-guidance. + + Args: + batch_size (int): + condition (CosmosCondition): condition information generated from self.conditioner + uncondition (CosmosCondition): uncondition information, possibily generated from self.conditioner + """ + x_sigma_max = torch.randn(batch_size, *self.state_shape, **self.tensor_kwargs) * self.sde.sigma_max + + def x0_fn(x, t): + cond_x0 = self.denoise(x, t, condition).x0 + uncond_x0 = self.denoise(x, t, uncondition).x0 + return cond_x0 + guidance * (cond_x0 - uncond_x0) + + return self.sampler(x0_fn, x_sigma_max, sigma_max=self.sde.sigma_max) + + def get_x0_fn_from_batch( + self, + data_batch: Dict, + guidance: float = 1.5, + is_negative_prompt: bool = False, + ) -> Callable: + """ + Generates a callable function `x0_fn` based on the provided data batch and guidance factor. + + This function first processes the input data batch through a conditioning workflow (`conditioner`) to obtain conditioned and unconditioned states. It then defines a nested function `x0_fn` which applies a denoising operation on an input `noise_x` at a given noise level `sigma` using both the conditioned and unconditioned states. + + Args: + - data_batch (Dict): A batch of data used for conditioning. The format and content of this dictionary should align with the expectations of the `self.conditioner` + - guidance (float, optional): A scalar value that modulates the influence of the conditioned state relative to the unconditioned state in the output. Defaults to 1.5. + - is_negative_prompt (bool): use negative prompt t5 in uncondition if true + + Returns: + - Callable: A function `x0_fn(noise_x, sigma)` that takes two arguments, `noise_x` and `sigma`, and return x0 predictoin + + The returned function is suitable for use in scenarios where a denoised state is required based on both conditioned and unconditioned inputs, with an adjustable level of guidance influence. + """ + if is_negative_prompt: + condition, uncondition = self.conditioner.get_condition_with_negative_prompt(data_batch) + else: + condition, uncondition = self.conditioner.get_condition_uncondition(data_batch) + + def x0_fn(noise_x: torch.Tensor, sigma: torch.Tensor) -> torch.Tensor: + cond_x0 = self.denoise(noise_x, sigma, condition).x0 + uncond_x0 = self.denoise(noise_x, sigma, uncondition).x0 + return cond_x0 + guidance * (cond_x0 - uncond_x0) + + return x0_fn + + def generate_samples_from_batch( + self, + data_batch: Dict, + guidance: float = 1.5, + seed: int = 1, + state_shape: Optional[Tuple] = None, + n_sample: Optional[int] = None, + is_negative_prompt: bool = False, + num_steps: int = 35, + solver_option: COMMON_SOLVER_OPTIONS = "2ab", + ) -> torch.Tensor: + """ + Args: + data_batch (dict): raw data batch draw from the training data loader. + iteration (int): Current iteration number. + guidance (float): guidance weights + seed (int): random seed + state_shape (tuple): shape of the state, default to self.state_shape if not provided + n_sample (int): number of samples to generate + is_negative_prompt (bool): use negative prompt t5 in uncondition if true + num_steps (int): number of steps for the diffusion process + solver_option (str): differential equation solver option, default to "2ab"~(mulitstep solver) + """ + x0_fn = self.get_x0_fn_from_batch(data_batch, guidance, is_negative_prompt=is_negative_prompt) + batch_size = n_sample or data_batch[self.input_data_key].shape[0] + state_shape = state_shape or self.state_shape + x_sigma_max = ( + misc.arch_invariant_rand( + (batch_size,) + tuple(state_shape), + torch.float32, + self.tensor_kwargs["device"], + seed, + ) + * self.sde.sigma_max + ) + return self.sampler( + x0_fn, x_sigma_max, sigma_max=self.sde.sigma_max, num_steps=num_steps, solver_option=solver_option + ) + + @torch.no_grad() + def validation_step( + self, data: dict[str, torch.Tensor], iteration: int + ) -> tuple[dict[str, torch.Tensor], torch.Tensor]: + """ + Current code does nothing. + """ + return {}, torch.tensor(0).to(**self.tensor_kwargs) + + @torch.no_grad() + def forward(self, xt, t, condition: CosmosCondition): + """ + Performs denoising on the input noise data, noise level, and condition + + Args: + xt (torch.Tensor): The input noise data. + sigma (torch.Tensor): The noise level. + condition (CosmosCondition): conditional information, generated from self.conditioner + + Returns: + DenoisePrediction: The denoised prediction, it includes clean data predicton (x0), \ + noise prediction (eps_pred) and optional confidence (logvar). + """ + return self.denoise(xt, t, condition) + + def init_optimizer_scheduler( + self, optimizer_config: LazyDict, scheduler_config: LazyDict + ) -> tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LRScheduler]: + """Creates the optimizer and scheduler for the model. + + Args: + config_model (ModelConfig): The config object for the model. + + Returns: + optimizer (torch.optim.Optimizer): The model optimizer. + scheduler (torch.optim.lr_scheduler.LRScheduler): The optimization scheduler. + """ + optimizer = lazy_instantiate(optimizer_config, model=self.model) + scheduler = get_base_scheduler(optimizer, self, scheduler_config) + return optimizer, scheduler + + def state_dict(self) -> Dict[str, Any]: + """ + Returns the current state of the model as a dictionary. + + Returns: + Dict: The current state of the model as a dictionary. + """ + return { + "model": self.model.state_dict(), + "ema": self.model_ema.state_dict() if self.config.ema.enabled else None, + } + + def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True, assign: bool = False): + """ + Loads a state dictionary into the model and optionally its EMA counterpart. + Different from torch strict=False mode, the method will not raise error for unmatched state shape while raise warning. + + Parameters: + state_dict (Mapping[str, Any]): A dictionary containing separate state dictionaries for the model and + potentially for an EMA version of the model under the keys 'model' and 'ema', respectively. + strict (bool, optional): If True, the method will enforce that the keys in the state dict match exactly + those in the model and EMA model (if applicable). Defaults to True. + assign (bool, optional): If True and in strict mode, will assign the state dictionary directly rather than + matching keys one-by-one. This is typically used when loading parts of state dicts + or using customized loading procedures. Defaults to False. + """ + if strict: + # the converted tpsp checkpoint has "ema" and it is None + if self.config.ema.enabled and state_dict["ema"] is not None: + ema_results: _IncompatibleKeys = self.model_ema.load_state_dict( + state_dict["ema"], strict=strict, assign=assign + ) + reg_results: _IncompatibleKeys = self.model.load_state_dict( + state_dict["model"], strict=strict, assign=assign + ) + if self.config.ema.enabled and state_dict["ema"] is not None: + return _IncompatibleKeys( + ema_results.missing_keys + reg_results.missing_keys, + ema_results.unexpected_keys + reg_results.unexpected_keys, + ) + return reg_results + else: + from cosmos_transfer1.diffusion.inference.inference_utils import non_strict_load_model + + log.critical("load model in non-strict mode") + log.critical(non_strict_load_model(self.model, state_dict["model"]), rank0_only=False) + if self.config.ema.enabled and state_dict["ema"] is not None: + log.critical("load ema model in non-strict mode") + log.critical(non_strict_load_model(self.model_ema, state_dict["ema"]), rank0_only=False) + + def get_ckpt_postfix(self) -> Tuple[str, int, int]: + """Get the checkpoint file postfix. + + Args: + iteration (int): The current iteration number. + + Returns: + postfix (str): The postfix of the checkpoint file. + rank_to_save ema (int), we will not save each ema model in each rank, \ + ema model with same rate will be saved once + total_ema_num (int) + """ + total_ema_num = min(self.config.ema.num, distributed.get_world_size()) + rank = distributed.get_rank() + if rank == 0: + return "", 0, total_ema_num + if self.config.ema.enabled: + if rank < self.config.ema.num: + return f"_RANK{rank}", rank, total_ema_num + return "", 0, total_ema_num # use rank 0 to save the checkpoint + + @contextmanager + def ema_scope(self, context=None, is_cpu=False): + if self.config.ema.enabled: + self.model_ema.cache(self.model.parameters(), is_cpu=is_cpu) + self.model_ema.copy_to(self.model) + if context is not None: + log.info(f"{context}: Switched to EMA weights") + try: + yield None + finally: + if self.config.ema.enabled: + self.model_ema.restore(self.model.parameters()) + if context is not None: + log.info(f"{context}: Restored training weights") + + +T = TypeVar("T", bound=DiffusionModel) + + +def diffusion_fsdp_class_decorator(base_class: Type[T]) -> Type[T]: + """ + Decorator for the FSDP class for the diffusion model, which handles the FSDP specific logic for the diffusion model. + """ + + class FSDPClass(base_class): + """ + Handle FSDP specific logic for the diffusion model. Including: + - FSDP model initialization + - FSDP model / optimizer save and loading + - Different from the original DiffusionModel, the impl of multi-rank EMA is a bit hacky. \ + We need to make sure sharded model weights for EMA and regular model are the same. + """ + + def __init__(self, config, fsdp_checkpointer: Any): + self.fsdp_checkpointer = fsdp_checkpointer + super().__init__(config) + + def set_up_model(self): + config = self.config + + # 1. build FSDP sharding strategy and device_mesh + strategy = { + "full": ShardingStrategy.FULL_SHARD, + "hybrid": ShardingStrategy.HYBRID_SHARD, + }[config.fsdp.sharding_strategy] + log.critical(f"Using {strategy} sharding strategy for FSDP") + + if config.fsdp.sharding_strategy == "hybrid": + sharding_group_size = getattr(config.fsdp, "sharding_group_size", 8) + device_mesh = hsdp_device_mesh( + sharding_group_size=sharding_group_size, + ) + shard_group = device_mesh.get_group(mesh_dim="shard") + replicate_group = device_mesh.get_group(mesh_dim="replicate") + fsdp_process_group = (shard_group, replicate_group) + else: + device_mesh = hsdp_device_mesh( + sharding_group_size=distributed.get_world_size(), + ) + shard_group = device_mesh.get_group(mesh_dim="shard") + fsdp_process_group = shard_group + + # We piggyback the `device_mesh` to megatron-core's `parallel_state` for global access. + # This is not megatron-core's original API. + parallel_state.fsdp_device_mesh = device_mesh + + def get_wrap_policy(_model): + if not hasattr(_model.net, "fsdp_wrap_block_cls"): + raise ValueError( + "Networks does not have fsdp_wrap_block_cls attribute, please check the net definition" + ) + fsdp_blocks_cls = _model.net.fsdp_wrap_block_cls + fsdp_blocks_cls = ( + list(fsdp_blocks_cls) if isinstance(fsdp_blocks_cls, (list, tuple, set)) else [fsdp_blocks_cls] + ) + log.critical(f"Using FSDP blocks {fsdp_blocks_cls}") + + log.critical(f"Using wrap policy {config.fsdp.policy}") + if config.fsdp.policy == "size": + min_num_params = getattr(config.fsdp, "min_num_params", 100) + log.critical(f"Using {min_num_params} as the minimum number of parameters for auto-wrap policy") + wrap_policy = functools.partial(size_based_auto_wrap_policy, min_num_params=min_num_params) + else: + from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy + + wrap_policy = functools.partial( + transformer_auto_wrap_policy, + transformer_layer_cls=set(fsdp_blocks_cls), + ) + return wrap_policy + + # 2. build naive pytorch model and load weights if exists + replica_idx, shard_idx = device_mesh.get_coordinate() + # 2.1 handle ema case first, since float32 is more expensive + if config.ema.enabled: + with misc.timer("Creating PyTorch model and loading weights for ema"): + model_ema = self.build_model().float() + model_ema.cuda().eval().requires_grad_(False) + if distributed.get_rank() == 0: + # only load model in rank0 to reduce network traffic + self.fsdp_checkpointer.load_model_during_init(model_ema, is_ema=True) + # sync ema model weights from rank0 + with misc.timer("Sync model states for EMA model"): + #! this is IMPORTANT, see the following comment about regular model for details + #! we broadcast the ema model first, since it is fp32 and costs more memory + distributed.sync_model_states(model_ema, device_mesh.get_group(mesh_dim="shard")) + torch.cuda.empty_cache() + distributed.sync_model_states(model_ema, device_mesh.get_group(mesh_dim="replicate")) + torch.cuda.empty_cache() + # for ema model with dfiferent rate, we download the model when necessary + if shard_idx == 0 and replica_idx > 0 and replica_idx < config.ema.num: + print("loading ema model in rank", replica_idx) + self.fsdp_checkpointer.load_model_during_init( + model_ema, + is_ema=True, + ema_id=replica_idx, + ) + print("finish loading ema model in rank", replica_idx) + # 2.1.2 create FSDP model for ema model + with misc.timer("Creating FSDP model for EMA model"): + self.model_ema = FSDP( + model_ema, + sync_module_states=True, # it can reduce network traffic by only loading model in rank0 and sync + process_group=device_mesh.get_group(mesh_dim=1), + sharding_strategy=ShardingStrategy.FULL_SHARD, + auto_wrap_policy=get_wrap_policy(model_ema), + device_id=torch.cuda.current_device(), + limit_all_gathers=True, + ) + + # extra ema model upate logic to the model + self.model_ema_worker = FastEmaModelUpdater() + s = 0.1 + replica_idx, shard_idx = device_mesh.get_coordinate() + divider = 2**replica_idx if replica_idx < config.ema.num else 1 + if replica_idx < config.ema.num: + if shard_idx == 0: + print(f"EMA: rank {replica_idx}, rate {config.ema.rate / divider}") + s = config.ema.rate / divider + self.ema_exp_coefficient = np.roots([1, 7, 16 - s**-2, 12 - s**-2]).real.max() + + torch.cuda.empty_cache() + + # 2.2 handle regular model + with misc.timer("Creating PyTorch model and loading weights for regular model"): + model = self.build_model().cuda().to(**self.tensor_kwargs) + + if distributed.get_rank() == 0: + # only load model in rank0 to reduce network traffic and sync later + self.fsdp_checkpointer.load_model_during_init(model, is_ema=False) + + #! overwrite the forward method so that it will invoke the FSDP-specific pre- and post-forward sharding logic + model.forward = super().training_step + #! this is IMPORTANT, though following two lines are identical to sync_module_states=True in FSDP + #! we do it twice so that following line can warm up and avoid OOM in aws 128+ nodes settings + #! qsh hypothesize that it is due to overhead of initialization of nccl network communication; + #! without it, peak mem : reg_model + ema_model + FSDP overhead + nccl communication initialization overhead + #! with it, peak men: reg_model + ema_model + FSDP overhead + #! it is tricky, but it works! + with misc.timer("Sync model states for regular model"): + distributed.sync_model_states(model, device_mesh.get_group(mesh_dim="shard")) + torch.cuda.empty_cache() + distributed.sync_model_states(model, device_mesh.get_group(mesh_dim="replicate")) + torch.cuda.empty_cache() + + with misc.timer("Creating FSDP model"): + self.model = FSDP( + model.to(**self.tensor_kwargs), + sync_module_states=True, # it can reduce network traffic by only loading model in rank0 and sync + sharding_strategy=strategy, + auto_wrap_policy=get_wrap_policy(model), + process_group=fsdp_process_group, + limit_all_gathers=True, + ) + + if self.config.fsdp.checkpoint: + fsdp_blocks_cls = model.net.fsdp_wrap_block_cls + fsdp_blocks_cls = ( + list(fsdp_blocks_cls) + if isinstance(fsdp_blocks_cls, (list, tuple, set)) + else [fsdp_blocks_cls] + ) + log.critical(f"Applying FSDP checkpointing with FSDP blocks: {fsdp_blocks_cls}") + apply_fsdp_checkpointing(self.model, list_block_cls=fsdp_blocks_cls) + + torch.cuda.empty_cache() + + def on_before_zero_grad( + self, optimizer: torch.optim.Optimizer, scheduler: torch.optim.lr_scheduler.LRScheduler, iteration: int + ) -> None: + del scheduler, optimizer + + if self.config.ema.enabled: + # calculate beta for EMA update + if iteration == 0: + beta = 0.0 + else: + i = iteration + 1 + beta = (1 - 1 / i) ** (self.ema_exp_coefficient + 1) + self.model_ema_worker.update_average(self.model, self.model_ema, beta=beta) + + def training_step( + self, data_batch: Dict[str, torch.Tensor], iteration: int + ) -> Tuple[Dict[str, torch.Tensor] | torch.Tensor]: + # ! Important!!! + # ! make sure the training step is the same as the forward method~(training_step in the super class) + # ! this is necessary to trigger the FSDP-specific pre- and post-forward sharding logic + return self.model(data_batch, iteration) + + def state_dict(self) -> Dict: + raise NotImplementedError( + "FSDPDiffModle does not support state_dict, use state_dict_model and FSDPCheckpointer" + ) + + @misc.timer("FSDP state_dict_model") + def state_dict_model(self) -> Dict: + with FSDP.summon_full_params(self.model): + pass + with FSDP.state_dict_type( + self.model, StateDictType.FULL_STATE_DICT, FullStateDictConfig(offload_to_cpu=True, rank0_only=True) + ): + model_state = self.model.state_dict() + if self.config.ema.enabled: + with FSDP.summon_full_params(self.model_ema): + pass + with FSDP.state_dict_type( + self.model_ema, + StateDictType.FULL_STATE_DICT, + FullStateDictConfig(offload_to_cpu=True, rank0_only=True), + ): + ema_model_state = self.model_ema.state_dict() + else: + ema_model_state = None + return { + "model": model_state, + "ema": ema_model_state, + } + + def load_state_dict(self, state_dict: Dict, strict: bool = True, assign: bool = False) -> None: + raise NotImplementedError("FSDPDiffModle does not support load_state_dict, using FSDPCheckpointer") + + def init_optimizer_scheduler( + self, optimizer_config: LazyDict, scheduler_config: LazyDict + ) -> tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LRScheduler]: + optimizer, scheduler = super().init_optimizer_scheduler(optimizer_config, scheduler_config) + self.fsdp_checkpointer.load_optim_scheduler_during_init( + self.model, + optimizer, + scheduler, + ) + return optimizer, scheduler + + @contextmanager + def ema_scope(self, context=None, is_cpu=False): + if self.config.ema.enabled: + self.model_ema_worker.cache(self.model.parameters(), is_cpu=is_cpu) + self.model_ema_worker.copy_to(src_model=self.model_ema, tgt_model=self.model) + if context is not None: + log.info(f"{context}: Switched to EMA weights") + try: + yield None + finally: + if self.config.ema.enabled: + self.model_ema_worker.restore(self.model.parameters()) + if context is not None: + log.info(f"{context}: Restored training weights") + + def get_ckpt_postfix(self) -> Tuple[str, int]: + """Get the checkpoint file postfix. check FSDPCheckpointer for more details + + Args: + iteration (int): The current iteration number. + + Returns: + postfix (str): The postfix of the checkpoint file. + replicate_idx, shard_idx (int), current gpu replicate_idx, shard_idx in FSDP \ + we will not save each ema model in each GPU, \ + ema model with same rate will be saved once + total_ema_num (int) + """ + mesh_shape = parallel_state.fsdp_device_mesh.shape + total_ema_num = min(self.config.ema.num, mesh_shape[0]) + replicate_idx, shard_idx = parallel_state.fsdp_device_mesh.get_coordinate() + if replicate_idx == 0: + return "", 0, shard_idx, total_ema_num + if self.config.ema.enabled: + if replicate_idx < self.config.ema.num: + return f"_RANK{replicate_idx}", replicate_idx, shard_idx, total_ema_num + return "", replicate_idx, shard_idx, total_ema_num + + return FSDPClass + + +@diffusion_fsdp_class_decorator +class FSDPDiffusionModel(DiffusionModel): + pass diff --git a/cosmos_transfer1/diffusion/training/models/model_multiview.py b/cosmos_transfer1/diffusion/training/models/model_multiview.py new file mode 100644 index 00000000..255df07c --- /dev/null +++ b/cosmos_transfer1/diffusion/training/models/model_multiview.py @@ -0,0 +1,224 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import copy +from typing import Callable, Dict, Optional, Tuple, Union + +import torch +from einops import rearrange +from megatron.core import parallel_state +from torch import Tensor + +from cosmos_transfer1.diffusion.diffusion.modules.res_sampler import COMMON_SOLVER_OPTIONS +from cosmos_transfer1.diffusion.module.parallel import cat_outputs_cp, split_inputs_cp +from cosmos_transfer1.diffusion.training.models.model import DiffusionModel, broadcast_condition +from cosmos_transfer1.diffusion.training.models.model_image import CosmosCondition, diffusion_fsdp_class_decorator +from cosmos_transfer1.utils import log, misc + + +class MultiviewDiffusionModel(DiffusionModel): + def __init__(self, config): + super().__init__(config) + self.n_views = config.n_views + + @torch.no_grad() + def encode(self, state: torch.Tensor) -> torch.Tensor: + state = rearrange(state, "B C (V T) H W -> (B V) C T H W", V=self.n_views) + encoded_state = self.vae.encode(state) + encoded_state = rearrange(encoded_state, "(B V) C T H W -> B C (V T) H W", V=self.n_views) * self.sigma_data + return encoded_state + + @torch.no_grad() + def decode(self, latent: torch.Tensor) -> torch.Tensor: + latent = rearrange(latent, "B C (V T) H W -> (B V) C T H W", V=self.n_views) + decoded_state = self.vae.decode(latent / self.sigma_data) + decoded_state = rearrange(decoded_state, "(B V) C T H W -> B C (V T) H W", V=self.n_views) + return decoded_state + + def compute_loss_with_epsilon_and_sigma( + self, + data_batch: dict[str, torch.Tensor], + x0_from_data_batch: torch.Tensor, + x0: torch.Tensor, + condition: CosmosCondition, + epsilon: torch.Tensor, + sigma: torch.Tensor, + ): + if self.is_image_batch(data_batch): + # Turn off CP + self.net.disable_context_parallel() + else: + if parallel_state.is_initialized(): + if parallel_state.get_context_parallel_world_size() > 1: + # Turn on CP + cp_group = parallel_state.get_context_parallel_group() + self.net.enable_context_parallel(cp_group) + log.debug("[CP] Split x0 and epsilon") + + x0 = rearrange(x0, "B C (V T) H W -> (B V) C T H W", V=self.n_views) + epsilon = rearrange(epsilon, "B C (V T) H W -> (B V) C T H W", V=self.n_views) + + x0 = split_inputs_cp(x=x0, seq_dim=2, cp_group=self.net.cp_group) + epsilon = split_inputs_cp(x=epsilon, seq_dim=2, cp_group=self.net.cp_group) + + x0 = rearrange(x0, "(B V) C T H W -> B C (V T) H W", V=self.n_views) + epsilon = rearrange(epsilon, "(B V) C T H W -> B C (V T) H W", V=self.n_views) + + output_batch, kendall_loss, pred_mse, edm_loss = super( + DiffusionModel, self + ).compute_loss_with_epsilon_and_sigma(data_batch, x0_from_data_batch, x0, condition, epsilon, sigma) + if not self.is_image_batch(data_batch): + if self.loss_reduce == "sum" and parallel_state.get_context_parallel_world_size() > 1: + kendall_loss *= parallel_state.get_context_parallel_world_size() + + return output_batch, kendall_loss, pred_mse, edm_loss + + def generate_samples_from_batch( + self, + data_batch: Dict, + guidance: float = 1.5, + seed: int = 1, + state_shape: Tuple | None = None, + n_sample: int | None = None, + is_negative_prompt: bool = False, + num_steps: int = 35, + solver_option: COMMON_SOLVER_OPTIONS = "2ab", + x_sigma_max: Optional[torch.Tensor] = None, + sigma_max: float | None = None, + guidance_other: Union[float, None] = None, + ) -> Tensor: + """ + Generate samples from the batch. Based on given batch, it will automatically determine whether to generate image or video samples. + Args: + data_batch (dict): raw data batch draw from the training data loader. + iteration (int): Current iteration number. + guidance (float): guidance weights + seed (int): random seed + state_shape (tuple): shape of the state, default to self.state_shape if not provided + n_sample (int): number of samples to generate + is_negative_prompt (bool): use negative prompt t5 in uncondition if true + num_steps (int): number of steps for the diffusion process + solver_option (str): differential equation solver option, default to "2ab"~(mulitstep solver) + """ + self._normalize_video_databatch_inplace(data_batch) + self._augment_image_dim_inplace(data_batch) + is_image_batch = self.is_image_batch(data_batch) + if n_sample is None: + input_key = self.input_image_key if is_image_batch else self.input_data_key + n_sample = data_batch[input_key].shape[0] + if state_shape is None: + if is_image_batch: + state_shape = (self.state_shape[0], 1, *self.state_shape[2:]) # C,T,H,W + x0_fn = self.get_x0_fn_from_batch( + data_batch, guidance, is_negative_prompt=is_negative_prompt, guidance_other=guidance_other + ) + x_sigma_max = ( + misc.arch_invariant_rand( + (n_sample,) + tuple(state_shape), + torch.float32, + self.tensor_kwargs["device"], + seed, + ) + * self.sde.sigma_max + ) + if self.net.is_context_parallel_enabled: + x_sigma_max = rearrange(x_sigma_max, "B C (V T) H W -> (B V) C T H W", V=self.n_views) + + x_sigma_max = split_inputs_cp(x=x_sigma_max, seq_dim=2, cp_group=self.net.cp_group) + + x_sigma_max = rearrange(x_sigma_max, "(B V) C T H W -> B C (V T) H W", V=self.n_views) + + samples = self.sampler( + x0_fn, x_sigma_max, num_steps=num_steps, sigma_max=self.sde.sigma_max, solver_option=solver_option + ) + if self.net.is_context_parallel_enabled: + samples = rearrange(samples, "B C (V T) H W -> (B V) C T H W", V=self.n_views) + samples = cat_outputs_cp(samples, seq_dim=2, cp_group=self.net.cp_group) + samples = rearrange(samples, "(B V) C T H W -> B C (V T) H W", V=self.n_views) + + return samples + + def get_x0_fn_from_batch( + self, + data_batch: Dict, + guidance: float = 1.5, + is_negative_prompt: bool = False, + guidance_other: Union[float, None] = None, + ) -> Callable: + """ + Generates a callable function `x0_fn` based on the provided data batch and guidance factor. + + This function first processes the input data batch through a conditioning workflow (`conditioner`) to obtain conditioned and unconditioned states. It then defines a nested function `x0_fn` which applies a denoising operation on an input `noise_x` at a given noise level `sigma` using both the conditioned and unconditioned states. + + Args: + - data_batch (Dict): A batch of data used for conditioning. The format and content of this dictionary should align with the expectations of the `self.conditioner` + - guidance (float, optional): A scalar value that modulates the influence of the conditioned state relative to the unconditioned state in the output. Defaults to 1.5. + - is_negative_prompt (bool): use negative prompt t5 in uncondition if true + + Returns: + - Callable: A function `x0_fn(noise_x, sigma)` that takes two arguments, `noise_x` and `sigma`, and return x0 predictoin + + The returned function is suitable for use in scenarios where a denoised state is required based on both conditioned and unconditioned inputs, with an adjustable level of guidance influence. + """ + if is_negative_prompt: + condition, uncondition = self.conditioner.get_condition_with_negative_prompt(data_batch) + else: + condition, uncondition = self.conditioner.get_condition_uncondition(data_batch) + + to_cp = self.net.is_context_parallel_enabled + # For inference, check if parallel_state is initialized + if parallel_state.is_initialized(): + condition = broadcast_condition(condition, to_tp=True, to_cp=to_cp) + uncondition = broadcast_condition(uncondition, to_tp=True, to_cp=to_cp) + else: + assert not to_cp, "parallel_state is not initialized, context parallel should be turned off." + + if guidance_other is not None: + # assume this is for inference time trajectory guidance for now + assert not parallel_state.is_initialized(), "Parallel state not supported with two guidances." + condition_other = copy.deepcopy(uncondition) + condition_other.trajectory = condition.trajectory + + def x0_fn(noise_x: torch.Tensor, sigma: torch.Tensor) -> torch.Tensor: + cond_x0 = self.denoise(noise_x, sigma, condition).x0 + uncond_x0 = self.denoise(noise_x, sigma, uncondition).x0 + cond_other_x0 = self.denoise(noise_x, sigma, condition_other).x0 + + raw_x0 = cond_x0 + guidance * (cond_x0 - uncond_x0) + guidance_other * (cond_other_x0 - uncond_x0) + + if "guided_image" in data_batch: + assert False, "not supported" + return raw_x0 + + else: + + def x0_fn(noise_x: torch.Tensor, sigma: torch.Tensor) -> torch.Tensor: + cond_x0 = self.denoise(noise_x, sigma, condition).x0 + uncond_x0 = self.denoise(noise_x, sigma, uncondition).x0 + raw_x0 = cond_x0 + guidance * (cond_x0 - uncond_x0) + if "guided_image" in data_batch: + # replacement trick that enables inpainting with base model + assert "guided_mask" in data_batch, "guided_mask should be in data_batch if guided_image is present" + guide_image = data_batch["guided_image"] + guide_mask = data_batch["guided_mask"] + raw_x0 = guide_mask * guide_image + (1 - guide_mask) * raw_x0 + return raw_x0 + + return x0_fn + + +@diffusion_fsdp_class_decorator +class FSDPDiffusionModel(MultiviewDiffusionModel): + pass diff --git a/cosmos_transfer1/diffusion/training/utils/fsdp_helper.py b/cosmos_transfer1/diffusion/training/utils/fsdp_helper.py new file mode 100644 index 00000000..1027504a --- /dev/null +++ b/cosmos_transfer1/diffusion/training/utils/fsdp_helper.py @@ -0,0 +1,159 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from contextlib import contextmanager +from functools import partial + +import torch +from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import ( + CheckpointImpl, + apply_activation_checkpointing, + checkpoint_wrapper, +) +from torch.distributed.device_mesh import init_device_mesh +from torch.distributed.fsdp import FullyShardedDataParallel as FSDP +from torch.distributed.fsdp._runtime_utils import ( + _post_forward, + _post_forward_reshard, + _pre_forward, + _pre_forward_unshard, + _root_pre_forward, +) +from torch.distributed.utils import _p_assert + +from cosmos_transfer1.utils import distributed, log + + +def apply_fsdp_checkpointing(model, list_block_cls): + """apply activation checkpointing to model + returns None as model is updated directly + """ + log.critical("--> applying fdsp activation checkpointing...") + non_reentrant_wrapper = partial( + checkpoint_wrapper, + # offload_to_cpu=False, + checkpoint_impl=CheckpointImpl.NO_REENTRANT, + ) + + def check_fn(submodule): + result = False + for block_cls in list_block_cls: + if isinstance(submodule, block_cls): + result = True + break + return result + + apply_activation_checkpointing(model, checkpoint_wrapper_fn=non_reentrant_wrapper, check_fn=check_fn) + + +@contextmanager +def possible_fsdp_scope( + model: torch.nn.Module, +): + enabled = isinstance(model, FSDP) + if enabled: + assert not torch.is_grad_enabled(), "FSDP context should be entered with grad disabled" + handle = model._handle + args, kwargs = [0], dict(dummy=0) + with torch.autograd.profiler.record_function("FullyShardedDataParallel.possible_fsdp_scope"): + args, kwargs = _root_pre_forward(model, model, args, kwargs) + unused = None + args, kwargs = _pre_forward( + model, + handle, + _pre_forward_unshard, + model._fsdp_wrapped_module, + args, + kwargs, + ) + if handle: + _p_assert( + handle.flat_param.device == model.compute_device, + "Expected `FlatParameter` to be on the compute device " + f"{model.compute_device} but got {handle.flat_param.device}", + ) + try: + yield None + finally: + if enabled: + output = {"output": 1} + _post_forward(model, handle, _post_forward_reshard, model, unused, output) + + +def hsdp_device_mesh(replica_group_size=None, sharding_group_size=None, device=None): + """ + Initializes a device mesh for use with Hybrid Sharding strategy in FSDP (HSDP) training. + + This function requires explicit sizes for replica and sharding groups to accommodate models + whose GPU fit is unknown, providing flexibility in distributed training setups. + + Args: + replica_group_size (int): The size of each replica group. Must be provided to ensure + the model fits within the available resources. + sharding_group_size (int): The size of each sharding group that the model can fit. Must be provided to + ensure the correct distribution of model parameters. + device (str, optional): The device to use (e.g., "cuda:0"). If None, defaults to "cuda" + with the local rank as the device index. + + Returns: + A device mesh object compatible with FSDP. + + Raises: + ValueError: If replica_group_size or sharding_group_size are not provided, or if the + world size is not evenly divisible by the sharding group size. + RuntimeError: If a valid device mesh cannot be created. + + Usage: + If your model fits on 4 GPUS, and you have 3 nodes of 8 GPUs, then: + Sharding_Group_Size = 4 + Replica_Groups_Size = (24 total gpus, 4 per sharding group) = 6 Replica Groups + >>> device_mesh = initialize_device_mesh(replica_group_size, sharding_group_size) + >>> sharded_model = FSDP(model, device_mesh=device_mesh, ...) + """ + + # world_size = int(os.getenv("WORLD_SIZE", "1")) + world_size = distributed.get_world_size() + if sharding_group_size is None: + sharding_group_size = min(world_size, 8) + sharding_group_size = min(sharding_group_size, world_size) + if replica_group_size is None: + replica_group_size = world_size // sharding_group_size + + device = device or "cuda" + + if world_size % sharding_group_size != 0: + raise ValueError( + f"World size {world_size} is not evenly divisible by " f"sharding group size {sharding_group_size}." + ) + + if (world_size // sharding_group_size) % replica_group_size != 0: + raise ValueError( + f"The calculated number of replica groups is not evenly divisible by " + f"replica_group_size {replica_group_size}." + ) + + device_mesh = init_device_mesh( + device, (replica_group_size, sharding_group_size), mesh_dim_names=("replicate", "shard") + ) + if device_mesh is None: + raise RuntimeError("Failed to create a valid device mesh.") + + log.critical( + f"Device mesh initialized with replica group size {replica_group_size} and sharding group size {sharding_group_size}" + ) + + return device_mesh From 0ec0a46c56ad2ad2407bd96ddb686e0946afd54b Mon Sep 17 00:00:00 2001 From: Qianli Ma Date: Fri, 11 Apr 2025 19:17:05 -0700 Subject: [PATCH 03/10] feat: add example Dataset class, add data augmentors, update config --- cosmos_transfer1/diffusion/conditioner.py | 1 + .../diffusion/config/base/conditioner.py | 5 + .../experiment/ctrl_7b_tp_121frames.py | 1 + .../diffusion/config/transfer/registry.py | 3 + .../diffusion/datasets/augmentor_provider.py | 171 +++++++++ .../diffusion/datasets/augmentors.py | 241 ++++++++++++ .../datasets/augmentors/basic_augmentors.py | 241 ++++++++++++ .../datasets/augmentors/merge_datadict.py | 54 +++ .../augmentors/text_transforms_for_video.py | 136 +++++++ .../diffusion/datasets/dataset_utils.py | 193 ++++++++++ .../diffusion/datasets/video_dataset.py | 353 ++++++++++++++++++ cosmos_transfer1/diffusion/training/train.py | 10 +- .../post-training_cosmos_transfer_7b_edge.md | 6 +- requirements.txt | 11 + 14 files changed, 1423 insertions(+), 3 deletions(-) create mode 100644 cosmos_transfer1/diffusion/datasets/augmentor_provider.py create mode 100644 cosmos_transfer1/diffusion/datasets/augmentors.py create mode 100644 cosmos_transfer1/diffusion/datasets/augmentors/basic_augmentors.py create mode 100644 cosmos_transfer1/diffusion/datasets/augmentors/merge_datadict.py create mode 100644 cosmos_transfer1/diffusion/datasets/augmentors/text_transforms_for_video.py create mode 100644 cosmos_transfer1/diffusion/datasets/dataset_utils.py create mode 100644 cosmos_transfer1/diffusion/datasets/video_dataset.py diff --git a/cosmos_transfer1/diffusion/conditioner.py b/cosmos_transfer1/diffusion/conditioner.py index 82016cce..51bc8fbe 100644 --- a/cosmos_transfer1/diffusion/conditioner.py +++ b/cosmos_transfer1/diffusion/conditioner.py @@ -340,6 +340,7 @@ class BaseWithCtrlCondition(VideoExtendCondition): base_model: Optional[torch.nn.Module] = None hint_key: Optional[str] = None control_weight: Optional[float] = 1.0 + num_layers_to_use: Optional[int] = -1 class VideoConditionerWithCtrl(VideoExtendConditioner): diff --git a/cosmos_transfer1/diffusion/config/base/conditioner.py b/cosmos_transfer1/diffusion/config/base/conditioner.py index 6a52df75..ec86a375 100644 --- a/cosmos_transfer1/diffusion/config/base/conditioner.py +++ b/cosmos_transfer1/diffusion/config/base/conditioner.py @@ -74,6 +74,11 @@ def forward(self, element: torch.Tensor) -> Dict[str, torch.Tensor]: element = element.to(dtype=self.dtype) return {key: element} + def details(self) -> str: + key = self.output_key if self.output_key else self.input_key + return f"Output key: {key} \n\tDtype: {self.dtype}" + + @attrs.define(slots=False) class FPSConfig: diff --git a/cosmos_transfer1/diffusion/config/training/experiment/ctrl_7b_tp_121frames.py b/cosmos_transfer1/diffusion/config/training/experiment/ctrl_7b_tp_121frames.py index f161f848..a6a65e70 100644 --- a/cosmos_transfer1/diffusion/config/training/experiment/ctrl_7b_tp_121frames.py +++ b/cosmos_transfer1/diffusion/config/training/experiment/ctrl_7b_tp_121frames.py @@ -94,6 +94,7 @@ def make_ctrlnet_config_7b_training( distributed_parallelism="ddp", logging_iter=200, max_iter=999_999_999, + timestamp_seed=True, ), model_parallel=dict( tensor_model_parallel_size=8, diff --git a/cosmos_transfer1/diffusion/config/transfer/registry.py b/cosmos_transfer1/diffusion/config/transfer/registry.py index e7b951ec..c9736f52 100644 --- a/cosmos_transfer1/diffusion/config/transfer/registry.py +++ b/cosmos_transfer1/diffusion/config/transfer/registry.py @@ -25,6 +25,9 @@ def register_experiment_ctrlnet(cs): + # TODO: maybe we should change the 'name' here; it's the dit-encoder for net_ctrl + # but current naming is the same as for the main 'net' group (which corresponds to the full DiT) + # that's defined in cosmos_transfer1/diffusion/config/registry.py. Isn't an error but could be confusing. cs.store(group="net_ctrl", package="model.net_ctrl", name="faditv2_7b", node=FADITV2EncoderConfig) cs.store(group="conditioner", package="model.conditioner", name="ctrlnet", node=BaseVideoConditionerWithCtrlConfig) diff --git a/cosmos_transfer1/diffusion/datasets/augmentor_provider.py b/cosmos_transfer1/diffusion/datasets/augmentor_provider.py new file mode 100644 index 00000000..f1953b7f --- /dev/null +++ b/cosmos_transfer1/diffusion/datasets/augmentor_provider.py @@ -0,0 +1,171 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from cosmos_transfer1.utils.lazy_config import LazyCall as L +from cosmos_transfer1.diffusion.datasets.augmentors.merge_datadict import DataDictMerger +from cosmos_transfer1.diffusion.datasets.augmentors.control_input import ( + VIDEO_RES_SIZE_INFO, + AddControlInputComb, + AddControlInput, +) +from cosmos_transfer1.diffusion.datasets.augmentors.basic_augmentors import ( + ResizeLargestSideAspectPreserving, + ReflectionPadding, +) +from cosmos_transfer1.diffusion.datasets.augmentors.text_transforms_for_video import ( + TextTransformForVideo, +) +from cosmos_transfer1.diffusion.config.transfer.conditioner import ( + CTRL_HINT_KEYS, + CTRL_HINT_KEYS_COMB, +) +from cosmos_transfer1.diffusion.datasets.video_dataset import CTRL_AUG_KEYS +from cosmos_transfer1.diffusion.config.transfer.blurs import BlurAugmentorConfig + +AUGMENTOR_OPTIONS = {} + + +def augmentor_register(key): + def decorator(func): + AUGMENTOR_OPTIONS[key] = func + return func + + return decorator + + +@augmentor_register("video_basic_augmentor") +def get_video_augmentor( + resolution: str, + text_transform_input_keys: str, + append_fps_frames: str = False, + blur_config=None, +): + return { + "merge_datadict": L(DataDictMerger)( + input_keys=["video"], + output_keys=[ + "video", + "fps", + "num_frames", + "chunk_index", + "frame_start", + "frame_end", + "orig_num_frames", + ], + ), + "resize_largest_side_aspect_ratio_preserving": L( + ResizeLargestSideAspectPreserving + )( + input_keys=["video"], + args={"size": VIDEO_RES_SIZE_INFO[resolution]}, + ), + "reflection_padding": L(ReflectionPadding)( + input_keys=["video"], + args={"size": VIDEO_RES_SIZE_INFO[resolution]}, + ), + "text_transform": L(TextTransformForVideo)( + input_keys=text_transform_input_keys, + args={ + "t5_tokens": {"num": 512, "dim": 1024}, + "is_mask_all_ones": True, + }, + ), + } + + +""" +register all the video ctrlnet augmentors for data loading +""" +for hint_key in CTRL_HINT_KEYS: + + def get_video_ctrlnet_augmentor(hint_key, use_random=True): + def _get_video_ctrlnet_augmentor( + resolution: str, + text_transform_input_keys: str, + blur_config: BlurAugmentorConfig, + ): + if hint_key == "control_input_human_kpts": + add_control_input = L(AddControlInputComb)( + input_keys=["", "video"], + output_keys=[hint_key], + args={ + "comb": CTRL_HINT_KEYS_COMB[hint_key], + "use_openpose_format": True, + "kpt_thr": 0.6, + "human_kpt_line_width": 4, + }, + use_random=use_random, + blur_config=blur_config, + ) + elif hint_key in CTRL_HINT_KEYS_COMB: + add_control_input = L(AddControlInputComb)( + input_keys=["", "video"], + output_keys=[hint_key], + args={"comb": CTRL_HINT_KEYS_COMB[hint_key]}, + use_random=use_random, + blur_config=blur_config, + ) + else: + add_control_input = L(AddControlInput)( + input_keys=["", "video"], + output_keys=[hint_key], + use_random=use_random, + blur_config=blur_config, + ) + input_keys = ["video"] + output_keys = [ + "video", + "fps", + "num_frames", + "chunk_index", + "frame_start", + "frame_end", + "orig_num_frames", + ] + for key, value in CTRL_AUG_KEYS.items(): + if key in hint_key: + input_keys.append(value) + output_keys.append(value) + augmentation = { + "merge_datadict": L(DataDictMerger)( + input_keys=input_keys, + output_keys=output_keys, + ), + "add_control_input": add_control_input, + "resize_largest_side_aspect_ratio_preserving": L( + ResizeLargestSideAspectPreserving + )( + input_keys=["video", hint_key], + args={"size": VIDEO_RES_SIZE_INFO[resolution]}, + ), + "reflection_padding": L(ReflectionPadding)( + input_keys=["video", hint_key], + args={"size": VIDEO_RES_SIZE_INFO[resolution]}, + ), + "text_transform": L(TextTransformForVideo)( + input_keys=text_transform_input_keys, + args={ + "t5_tokens": {"num": 512, "dim": 1024}, + "is_mask_all_ones": True, + }, + ), + } + return augmentation + + return _get_video_ctrlnet_augmentor + + augmentor_register(f"video_ctrlnet_augmentor_{hint_key}")( + get_video_ctrlnet_augmentor(hint_key) + ) diff --git a/cosmos_transfer1/diffusion/datasets/augmentors.py b/cosmos_transfer1/diffusion/datasets/augmentors.py new file mode 100644 index 00000000..ee867e9c --- /dev/null +++ b/cosmos_transfer1/diffusion/datasets/augmentors.py @@ -0,0 +1,241 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Optional + +import omegaconf +import torch +import torchvision.transforms.functional as transforms_F + +from cosmos_transfer1.diffusion.datasets.augmentors.control_input import Augmentor +from cosmos_transfer1.diffusion.datasets.dataset_utils import obtain_image_size, obtain_augmentation_size + + +class ReflectionPadding(Augmentor): + def __init__(self, input_keys: list, output_keys: Optional[list] = None, args: Optional[dict] = None) -> None: + super().__init__(input_keys, output_keys, args) + + def __call__(self, data_dict: dict) -> dict: + r"""Performs reflection padding. This function also returns a padding mask. + + Args: + data_dict (dict): Input data dict + Returns: + data_dict (dict): Output dict where images are center cropped. + """ + + assert self.args is not None, "Please specify args in augmentation" + if self.output_keys is None: + self.output_keys = self.input_keys + + # Obtain image and augmentation sizes + orig_w, orig_h = obtain_image_size(data_dict, self.input_keys) + target_size = obtain_augmentation_size(data_dict, self.args) + + assert isinstance(target_size, (tuple, omegaconf.listconfig.ListConfig)), "Please specify target size as tuple" + target_w, target_h = target_size + + target_w = int(target_w) + target_h = int(target_h) + + # Calculate padding vals + padding_left = int((target_w - orig_w) / 2) + padding_right = target_w - orig_w - padding_left + padding_top = int((target_h - orig_h) / 2) + padding_bottom = target_h - orig_h - padding_top + padding_vals = [padding_left, padding_top, padding_right, padding_bottom] + + for inp_key, out_key in zip(self.input_keys, self.output_keys): + if max(padding_vals[0], padding_vals[2]) >= orig_w or max(padding_vals[1], padding_vals[3]) >= orig_h: + # In this case, we can't perform reflection padding. This is because padding values + # are larger than the image size. So, perform edge padding instead. + data_dict[out_key] = transforms_F.pad(data_dict[inp_key], padding_vals, padding_mode="edge") + else: + # Perform reflection padding + data_dict[out_key] = transforms_F.pad(data_dict[inp_key], padding_vals, padding_mode="reflect") + + if out_key != inp_key: + del data_dict[inp_key] + + # Return padding_mask when padding is performed. + # Padding mask denotes which pixels are padded. + padding_mask = torch.ones((1, target_h, target_w)) + padding_mask[:, padding_top : (padding_top + orig_h), padding_left : (padding_left + orig_w)] = 0 + data_dict["padding_mask"] = padding_mask + data_dict["image_size"] = torch.tensor([target_h, target_w, orig_h, orig_w], dtype=torch.float) + + return data_dict + + + +class ResizeSmallestSide(Augmentor): + def __init__(self, input_keys: list, output_keys: Optional[list] = None, args: Optional[dict] = None) -> None: + super().__init__(input_keys, output_keys, args) + + def __call__(self, data_dict: dict) -> dict: + r"""Performs resizing to smaller side + + Args: + data_dict (dict): Input data dict + Returns: + data_dict (dict): Output dict where images are resized + """ + + if self.output_keys is None: + self.output_keys = self.input_keys + assert self.args is not None, "Please specify args in augmentations" + + for inp_key, out_key in zip(self.input_keys, self.output_keys): + out_size = obtain_augmentation_size(data_dict, self.args) + assert isinstance(out_size, int), "Arg size in resize should be an integer" + data_dict[out_key] = transforms_F.resize( + data_dict[inp_key], + size=out_size, # type: ignore + interpolation=getattr(self.args, "interpolation", transforms_F.InterpolationMode.BICUBIC), + antialias=True, + ) + if out_key != inp_key: + del data_dict[inp_key] + return data_dict + + +class ResizeLargestSide(Augmentor): + def __init__(self, input_keys: list, output_keys: Optional[list] = None, args: Optional[dict] = None) -> None: + super().__init__(input_keys, output_keys, args) + + def __call__(self, data_dict: dict) -> dict: + r"""Performs resizing to larger side + + Args: + data_dict (dict): Input data dict + Returns: + data_dict (dict): Output dict where images are resized + """ + + if self.output_keys is None: + self.output_keys = self.input_keys + assert self.args is not None, "Please specify args in augmentations" + + for inp_key, out_key in zip(self.input_keys, self.output_keys): + out_size = obtain_augmentation_size(data_dict, self.args) + assert isinstance(out_size, int), "Arg size in resize should be an integer" + orig_w, orig_h = obtain_image_size(data_dict, self.input_keys) + + scaling_ratio = min(out_size / orig_w, out_size / orig_h) + target_size = [int(scaling_ratio * orig_h), int(scaling_ratio * orig_w)] + + data_dict[out_key] = transforms_F.resize( + data_dict[inp_key], + size=target_size, + interpolation=getattr(self.args, "interpolation", transforms_F.InterpolationMode.BICUBIC), + antialias=True, + ) + if out_key != inp_key: + del data_dict[inp_key] + return data_dict + + +class ResizeSmallestSideAspectPreserving(Augmentor): + def __init__(self, input_keys: list, output_keys: Optional[list] = None, args: Optional[dict] = None) -> None: + super().__init__(input_keys, output_keys, args) + + def __call__(self, data_dict: dict) -> dict: + r"""Performs aspect-ratio preserving resizing. + Image is resized to the dimension which has the smaller ratio of (size / target_size). + First we compute (w_img / w_target) and (h_img / h_target) and resize the image + to the dimension that has the smaller of these ratios. + + Args: + data_dict (dict): Input data dict + Returns: + data_dict (dict): Output dict where images are resized + """ + + if self.output_keys is None: + self.output_keys = self.input_keys + assert self.args is not None, "Please specify args in augmentations" + + img_size = obtain_augmentation_size(data_dict, self.args) + assert isinstance( + img_size, (tuple, omegaconf.listconfig.ListConfig) + ), f"Arg size in resize should be a tuple, get {type(img_size)}, {img_size}" + img_w, img_h = img_size + + orig_w, orig_h = obtain_image_size(data_dict, self.input_keys) + scaling_ratio = max((img_w / orig_w), (img_h / orig_h)) + target_size = (int(scaling_ratio * orig_h + 0.5), int(scaling_ratio * orig_w + 0.5)) + + assert ( + target_size[0] >= img_h and target_size[1] >= img_w + ), f"Resize error. orig {(orig_w, orig_h)} desire {img_size} compute {target_size}" + + for inp_key, out_key in zip(self.input_keys, self.output_keys): + data_dict[out_key] = transforms_F.resize( + data_dict[inp_key], + size=target_size, # type: ignore + interpolation=getattr(self.args, "interpolation", transforms_F.InterpolationMode.BICUBIC), + antialias=True, + ) + + if out_key != inp_key: + del data_dict[inp_key] + return data_dict + + +class ResizeLargestSideAspectPreserving(Augmentor): + def __init__(self, input_keys: list, output_keys: Optional[list] = None, args: Optional[dict] = None) -> None: + super().__init__(input_keys, output_keys, args) + + def __call__(self, data_dict: dict) -> dict: + r"""Performs aspect-ratio preserving resizing. + Image is resized to the dimension which has the larger ratio of (size / target_size). + First we compute (w_img / w_target) and (h_img / h_target) and resize the image + to the dimension that has the larger of these ratios. + + Args: + data_dict (dict): Input data dict + Returns: + data_dict (dict): Output dict where images are resized + """ + + if self.output_keys is None: + self.output_keys = self.input_keys + assert self.args is not None, "Please specify args in augmentations" + + img_size = obtain_augmentation_size(data_dict, self.args) + assert isinstance( + img_size, (tuple, omegaconf.listconfig.ListConfig) + ), f"Arg size in resize should be a tuple, get {type(img_size)}, {img_size}" + img_w, img_h = img_size + + orig_w, orig_h = obtain_image_size(data_dict, self.input_keys) + scaling_ratio = min((img_w / orig_w), (img_h / orig_h)) + target_size = (int(scaling_ratio * orig_h + 0.5), int(scaling_ratio * orig_w + 0.5)) + + assert ( + target_size[0] <= img_h and target_size[1] <= img_w + ), f"Resize error. orig {(orig_w, orig_h)} desire {img_size} compute {target_size}" + + for inp_key, out_key in zip(self.input_keys, self.output_keys): + data_dict[out_key] = transforms_F.resize( + data_dict[inp_key], + size=target_size, # type: ignore + interpolation=getattr(self.args, "interpolation", transforms_F.InterpolationMode.BICUBIC), + antialias=True, + ) + + if out_key != inp_key: + del data_dict[inp_key] + return data_dict diff --git a/cosmos_transfer1/diffusion/datasets/augmentors/basic_augmentors.py b/cosmos_transfer1/diffusion/datasets/augmentors/basic_augmentors.py new file mode 100644 index 00000000..ee867e9c --- /dev/null +++ b/cosmos_transfer1/diffusion/datasets/augmentors/basic_augmentors.py @@ -0,0 +1,241 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Optional + +import omegaconf +import torch +import torchvision.transforms.functional as transforms_F + +from cosmos_transfer1.diffusion.datasets.augmentors.control_input import Augmentor +from cosmos_transfer1.diffusion.datasets.dataset_utils import obtain_image_size, obtain_augmentation_size + + +class ReflectionPadding(Augmentor): + def __init__(self, input_keys: list, output_keys: Optional[list] = None, args: Optional[dict] = None) -> None: + super().__init__(input_keys, output_keys, args) + + def __call__(self, data_dict: dict) -> dict: + r"""Performs reflection padding. This function also returns a padding mask. + + Args: + data_dict (dict): Input data dict + Returns: + data_dict (dict): Output dict where images are center cropped. + """ + + assert self.args is not None, "Please specify args in augmentation" + if self.output_keys is None: + self.output_keys = self.input_keys + + # Obtain image and augmentation sizes + orig_w, orig_h = obtain_image_size(data_dict, self.input_keys) + target_size = obtain_augmentation_size(data_dict, self.args) + + assert isinstance(target_size, (tuple, omegaconf.listconfig.ListConfig)), "Please specify target size as tuple" + target_w, target_h = target_size + + target_w = int(target_w) + target_h = int(target_h) + + # Calculate padding vals + padding_left = int((target_w - orig_w) / 2) + padding_right = target_w - orig_w - padding_left + padding_top = int((target_h - orig_h) / 2) + padding_bottom = target_h - orig_h - padding_top + padding_vals = [padding_left, padding_top, padding_right, padding_bottom] + + for inp_key, out_key in zip(self.input_keys, self.output_keys): + if max(padding_vals[0], padding_vals[2]) >= orig_w or max(padding_vals[1], padding_vals[3]) >= orig_h: + # In this case, we can't perform reflection padding. This is because padding values + # are larger than the image size. So, perform edge padding instead. + data_dict[out_key] = transforms_F.pad(data_dict[inp_key], padding_vals, padding_mode="edge") + else: + # Perform reflection padding + data_dict[out_key] = transforms_F.pad(data_dict[inp_key], padding_vals, padding_mode="reflect") + + if out_key != inp_key: + del data_dict[inp_key] + + # Return padding_mask when padding is performed. + # Padding mask denotes which pixels are padded. + padding_mask = torch.ones((1, target_h, target_w)) + padding_mask[:, padding_top : (padding_top + orig_h), padding_left : (padding_left + orig_w)] = 0 + data_dict["padding_mask"] = padding_mask + data_dict["image_size"] = torch.tensor([target_h, target_w, orig_h, orig_w], dtype=torch.float) + + return data_dict + + + +class ResizeSmallestSide(Augmentor): + def __init__(self, input_keys: list, output_keys: Optional[list] = None, args: Optional[dict] = None) -> None: + super().__init__(input_keys, output_keys, args) + + def __call__(self, data_dict: dict) -> dict: + r"""Performs resizing to smaller side + + Args: + data_dict (dict): Input data dict + Returns: + data_dict (dict): Output dict where images are resized + """ + + if self.output_keys is None: + self.output_keys = self.input_keys + assert self.args is not None, "Please specify args in augmentations" + + for inp_key, out_key in zip(self.input_keys, self.output_keys): + out_size = obtain_augmentation_size(data_dict, self.args) + assert isinstance(out_size, int), "Arg size in resize should be an integer" + data_dict[out_key] = transforms_F.resize( + data_dict[inp_key], + size=out_size, # type: ignore + interpolation=getattr(self.args, "interpolation", transforms_F.InterpolationMode.BICUBIC), + antialias=True, + ) + if out_key != inp_key: + del data_dict[inp_key] + return data_dict + + +class ResizeLargestSide(Augmentor): + def __init__(self, input_keys: list, output_keys: Optional[list] = None, args: Optional[dict] = None) -> None: + super().__init__(input_keys, output_keys, args) + + def __call__(self, data_dict: dict) -> dict: + r"""Performs resizing to larger side + + Args: + data_dict (dict): Input data dict + Returns: + data_dict (dict): Output dict where images are resized + """ + + if self.output_keys is None: + self.output_keys = self.input_keys + assert self.args is not None, "Please specify args in augmentations" + + for inp_key, out_key in zip(self.input_keys, self.output_keys): + out_size = obtain_augmentation_size(data_dict, self.args) + assert isinstance(out_size, int), "Arg size in resize should be an integer" + orig_w, orig_h = obtain_image_size(data_dict, self.input_keys) + + scaling_ratio = min(out_size / orig_w, out_size / orig_h) + target_size = [int(scaling_ratio * orig_h), int(scaling_ratio * orig_w)] + + data_dict[out_key] = transforms_F.resize( + data_dict[inp_key], + size=target_size, + interpolation=getattr(self.args, "interpolation", transforms_F.InterpolationMode.BICUBIC), + antialias=True, + ) + if out_key != inp_key: + del data_dict[inp_key] + return data_dict + + +class ResizeSmallestSideAspectPreserving(Augmentor): + def __init__(self, input_keys: list, output_keys: Optional[list] = None, args: Optional[dict] = None) -> None: + super().__init__(input_keys, output_keys, args) + + def __call__(self, data_dict: dict) -> dict: + r"""Performs aspect-ratio preserving resizing. + Image is resized to the dimension which has the smaller ratio of (size / target_size). + First we compute (w_img / w_target) and (h_img / h_target) and resize the image + to the dimension that has the smaller of these ratios. + + Args: + data_dict (dict): Input data dict + Returns: + data_dict (dict): Output dict where images are resized + """ + + if self.output_keys is None: + self.output_keys = self.input_keys + assert self.args is not None, "Please specify args in augmentations" + + img_size = obtain_augmentation_size(data_dict, self.args) + assert isinstance( + img_size, (tuple, omegaconf.listconfig.ListConfig) + ), f"Arg size in resize should be a tuple, get {type(img_size)}, {img_size}" + img_w, img_h = img_size + + orig_w, orig_h = obtain_image_size(data_dict, self.input_keys) + scaling_ratio = max((img_w / orig_w), (img_h / orig_h)) + target_size = (int(scaling_ratio * orig_h + 0.5), int(scaling_ratio * orig_w + 0.5)) + + assert ( + target_size[0] >= img_h and target_size[1] >= img_w + ), f"Resize error. orig {(orig_w, orig_h)} desire {img_size} compute {target_size}" + + for inp_key, out_key in zip(self.input_keys, self.output_keys): + data_dict[out_key] = transforms_F.resize( + data_dict[inp_key], + size=target_size, # type: ignore + interpolation=getattr(self.args, "interpolation", transforms_F.InterpolationMode.BICUBIC), + antialias=True, + ) + + if out_key != inp_key: + del data_dict[inp_key] + return data_dict + + +class ResizeLargestSideAspectPreserving(Augmentor): + def __init__(self, input_keys: list, output_keys: Optional[list] = None, args: Optional[dict] = None) -> None: + super().__init__(input_keys, output_keys, args) + + def __call__(self, data_dict: dict) -> dict: + r"""Performs aspect-ratio preserving resizing. + Image is resized to the dimension which has the larger ratio of (size / target_size). + First we compute (w_img / w_target) and (h_img / h_target) and resize the image + to the dimension that has the larger of these ratios. + + Args: + data_dict (dict): Input data dict + Returns: + data_dict (dict): Output dict where images are resized + """ + + if self.output_keys is None: + self.output_keys = self.input_keys + assert self.args is not None, "Please specify args in augmentations" + + img_size = obtain_augmentation_size(data_dict, self.args) + assert isinstance( + img_size, (tuple, omegaconf.listconfig.ListConfig) + ), f"Arg size in resize should be a tuple, get {type(img_size)}, {img_size}" + img_w, img_h = img_size + + orig_w, orig_h = obtain_image_size(data_dict, self.input_keys) + scaling_ratio = min((img_w / orig_w), (img_h / orig_h)) + target_size = (int(scaling_ratio * orig_h + 0.5), int(scaling_ratio * orig_w + 0.5)) + + assert ( + target_size[0] <= img_h and target_size[1] <= img_w + ), f"Resize error. orig {(orig_w, orig_h)} desire {img_size} compute {target_size}" + + for inp_key, out_key in zip(self.input_keys, self.output_keys): + data_dict[out_key] = transforms_F.resize( + data_dict[inp_key], + size=target_size, # type: ignore + interpolation=getattr(self.args, "interpolation", transforms_F.InterpolationMode.BICUBIC), + antialias=True, + ) + + if out_key != inp_key: + del data_dict[inp_key] + return data_dict diff --git a/cosmos_transfer1/diffusion/datasets/augmentors/merge_datadict.py b/cosmos_transfer1/diffusion/datasets/augmentors/merge_datadict.py new file mode 100644 index 00000000..3b06db99 --- /dev/null +++ b/cosmos_transfer1/diffusion/datasets/augmentors/merge_datadict.py @@ -0,0 +1,54 @@ + +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http:..www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from typing import Optional + +from cosmos_transfer1.diffusion.datasets.augmentors.control_input import Augmentor +from cosmos_transfer1.utils import log + + +class DataDictMerger(Augmentor): + def __init__(self, input_keys: list, output_keys: Optional[list] = None, args: Optional[dict] = None) -> None: + super().__init__(input_keys, output_keys, args) + + def __call__(self, data_dict: dict) -> dict: + r"""Merge the dictionary associated with the input keys into data_dict. Only keys in output_keys are merged. + + Args: + data_dict (dict): Input data dict + Returns: + data_dict (dict): Output dict with dictionary associated with the input keys merged. + """ + for key in self.input_keys: + if key not in data_dict: + log.warning( + f"DataDictMerger dataloader error: missing {key}, {data_dict['__url__']}, {data_dict['__key__']}", + rank0_only=False, + ) + return None + key_dict = data_dict.pop(key) + if key == "depth" and "depth" in self.output_keys: + data_dict["depth"] = key_dict + if key == "human_annotation" and "human_annotation" in self.output_keys: + data_dict["human_annotation"] = key_dict + elif key == "segmentation" and "segmentation" in self.output_keys: + data_dict["segmentation"] = key_dict + for sub_key in key_dict: + if sub_key in self.output_keys and sub_key not in data_dict: + data_dict[sub_key] = key_dict[sub_key] + del key_dict + return data_dict diff --git a/cosmos_transfer1/diffusion/datasets/augmentors/text_transforms_for_video.py b/cosmos_transfer1/diffusion/datasets/augmentors/text_transforms_for_video.py new file mode 100644 index 00000000..1ad5c2cb --- /dev/null +++ b/cosmos_transfer1/diffusion/datasets/augmentors/text_transforms_for_video.py @@ -0,0 +1,136 @@ + +import random +from typing import Optional + +import numpy as np +import torch + +from cosmos_transfer1.diffusion.datasets.augmentors.control_input import Augmentor +from cosmos_transfer1.utils import log + + +def pad_and_resize( + arr_np: np.ndarray, ntokens: int, is_mask_all_ones: bool = False +) -> tuple[torch.Tensor, torch.Tensor]: + r"""Function for padding and resizing a numpy array. + Args: + arr (np.ndarray): Input array + ntokens (int): Number of output tokens after padding + is_mask_all_ones (bool): if true, set mask to ones + Returns: + arr_padded (torch.Tensor): Padded output tensor + mask (torch.Tensor): Padding mask + """ + + if isinstance(arr_np, np.ndarray): + arr = torch.from_numpy(arr_np) + elif isinstance(arr_np, torch.Tensor): + arr = arr_np.clone().detach() + else: + raise TypeError("`arr_np` should be a numpy array or torch tensor.") + embed_dim = arr.shape[1] + + arr_padded = torch.zeros(ntokens, embed_dim, device=arr.device, dtype=torch.float32) + + # If the input text is larger than num_text_tokens, clip it. + if arr.shape[0] > ntokens: + arr = arr[0:ntokens] + + mask = torch.LongTensor(ntokens).zero_() + if len(arr.shape) > 1: + mask[0 : arr.shape[0]] = 1 + + if len(arr.shape) > 1: + arr_padded[0 : arr.shape[0]] = arr + + if is_mask_all_ones: + mask.fill_(1) + + return arr_padded, mask + +class TextTransformForVideo(Augmentor): + def __init__(self, input_keys: dict, output_keys: Optional[list] = None, args: Optional[dict] = None) -> None: + super().__init__(input_keys, output_keys, args) + + def __call__(self, data_dict: dict) -> dict: + r"""Performs text transformation. + + Args: + data_dict (dict): Input data dict + Returns: + data_dict (dict): Output dict with captions and t5 embeddings added + """ + data_source = data_dict["__url__"].meta.source + input_keys_by_source = self.input_keys[data_source] + + if "chunk_index" not in data_dict: + log.warning( + "Chunk_index is not in data_dict, set chunk_index to be 0. This should only happen for sampling." + ) + data_dict["chunk_index"] = 0 # this is for sampling only, whereas decoder is not loaded + try: + windows = data_dict[input_keys_by_source["ai_caption"]]["windows"] + n_windows = len(windows) + chunk_index = data_dict["chunk_index"] + + if chunk_index == n_windows: + # This will only happen when the number of captions does not match number of chunks due to re-transcoding the videos. + log.info( + f"Found {data_dict['orig_num_frames']} in video but captioning is done with videos of {windows[-1]['end_frame']} frames. This mismatch is due to video re-transcoding.", + rank0_only=False, + ) + chunk_index -= 1 + + selected_caption_window = windows[chunk_index] + except Exception as e: + log.warning( + f"TextTransform dataloader error: {data_dict['__url__']}, {data_dict['__key__']}, {data_dict['chunk_index']}\n error {e}", + rank0_only=False, + ) + return None + + try: + if "vila_caption" in selected_caption_window: + caption_type = "vila_caption" + elif "qwen_caption" in selected_caption_window: + caption_type = "qwen_caption" + else: + caption_type = random.choices(["long_caption", "short_caption"], weights=[0.95, 0.05], k=1)[0] + # TODO(hanzim): make probabilities configurable when we need it + data_dict["ai_caption"] = selected_caption_window[caption_type] + except Exception as e: + log.warning( + f"TextTransform dataloader error: {data_dict['__url__']}, {data_dict['__key__']}, {selected_caption_window}\n error {e}", + rank0_only=False, + ) + return None + + # TODO(hanzim): temp fix for samples with gt_caption = None + if data_dict["ai_caption"] is None: + data_dict["ai_caption"] = "" + del data_dict[input_keys_by_source["ai_caption"]] + + ai_caption_embedding_data = data_dict[input_keys_by_source["ai_caption_embedding"]] + try: + if caption_type in ["vila_caption", "qwen_caption"]: + t5_embedding = ai_caption_embedding_data[data_dict["chunk_index"]] + else: + t5_embedding = ai_caption_embedding_data[data_dict["chunk_index"]][ + caption_type.replace("_caption", "") + ] # t5_embedding is saved in {"short": array, "long": array} format + except Exception as e: + log.warning( + f"TextTransform dataloader error: {data_dict['__url__']}, {data_dict['__key__']}, {data_dict['chunk_index']}, {len(ai_caption_embedding_data)} \n error {e}", + rank0_only=False, + ) + return None + out_t5, out_t5_mask = pad_and_resize( + t5_embedding, + self.args["t5_tokens"]["num"], + is_mask_all_ones=self.args["is_mask_all_ones"], + ) + data_dict["t5_text_embeddings"] = out_t5 + data_dict["t5_text_mask"] = out_t5_mask + del data_dict[input_keys_by_source["ai_caption_embedding"]] + + return data_dict diff --git a/cosmos_transfer1/diffusion/datasets/dataset_utils.py b/cosmos_transfer1/diffusion/datasets/dataset_utils.py new file mode 100644 index 00000000..0969aaa1 --- /dev/null +++ b/cosmos_transfer1/diffusion/datasets/dataset_utils.py @@ -0,0 +1,193 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any, Optional + +import torch +import torchvision.transforms.functional as transforms_F +from PIL import Image + + +def obtain_image_size(data_dict: dict, input_keys: list) -> tuple[int, int]: + r"""Function for obtaining the image size from the data dict. + + Args: + data_dict (dict): Input data dict + input_keys (list): List of input keys + Returns: + width (int): Width of the input image + height (int): Height of the input image + """ + + data1 = data_dict[input_keys[0]] + if isinstance(data1, Image.Image): + width, height = data1.size + elif isinstance(data1, torch.Tensor): + height, width = data1.size()[-2:] + else: + raise ValueError("data to random crop should be PIL Image or tensor") + + return width, height + + +def obtain_augmentation_size(data_dict: dict, augmentor_cfg: dict) -> Union[int, tuple]: + r"""Function for obtaining size of the augmentation. + When dealing with multi-aspect ratio dataloaders, we need to + find the augmentation size from the aspect ratio of the data. + + Args: + data_dict (dict): Input data dict + augmentor_cfg (dict): Augmentor config + Returns: + aug_size (int): Size of augmentation + """ + if "__url__" in data_dict and "aspect_ratio" in data_dict["__url__"].meta.opts: + aspect_ratio = data_dict["__url__"].meta.opts["aspect_ratio"] + aug_size = augmentor_cfg["size"][aspect_ratio] + else: # Non-webdataset format + aspect_ratio = data_dict["aspect_ratio"] + aug_size = augmentor_cfg["size"][aspect_ratio] + return aug_size + + +class Augmentor: + def __init__(self, input_keys: list, output_keys: Optional[list] = None, args: Optional[dict] = None) -> None: + r"""Base augmentor class + + Args: + input_keys (list): List of input keys + output_keys (list): List of output keys + args (dict): Arguments associated with the augmentation + """ + self.input_keys = input_keys + self.output_keys = output_keys + self.args = args + + def __call__(self, *args: Any, **kwds: Any) -> Any: + raise ValueError("Augmentor not implemented") + + +class ResizeSmallestSideAspectPreserving(Augmentor): + def __init__(self, input_keys: list, output_keys: Optional[list] = None, args: Optional[dict] = None) -> None: + super().__init__(input_keys, output_keys, args) + + def __call__(self, data_dict: dict) -> dict: + r"""Performs aspect-ratio preserving resizing. + Image is resized to the dimension which has the smaller ratio of (size / target_size). + First we compute (w_img / w_target) and (h_img / h_target) and resize the image + to the dimension that has the smaller of these ratios. + + Args: + data_dict (dict): Input data dict + Returns: + data_dict (dict): Output dict where images are resized + """ + + if self.output_keys is None: + self.output_keys = self.input_keys + assert self.args is not None, "Please specify args in augmentations" + + img_w, img_h = self.args["img_w"], self.args["img_h"] + + orig_w, orig_h = obtain_image_size(data_dict, self.input_keys) + scaling_ratio = max((img_w / orig_w), (img_h / orig_h)) + target_size = (int(scaling_ratio * orig_h + 0.5), int(scaling_ratio * orig_w + 0.5)) + + assert ( + target_size[0] >= img_h and target_size[1] >= img_w + ), f"Resize error. orig {(orig_w, orig_h)} desire {(img_w, img_h)} compute {target_size}" + + for inp_key, out_key in zip(self.input_keys, self.output_keys): + data_dict[out_key] = transforms_F.resize( + data_dict[inp_key], + size=target_size, # type: ignore + interpolation=getattr(self.args, "interpolation", transforms_F.InterpolationMode.BICUBIC), + antialias=True, + ) + + if out_key != inp_key: + del data_dict[inp_key] + return data_dict + + +class CenterCrop(Augmentor): + def __init__(self, input_keys: list, output_keys: Optional[list] = None, args: Optional[dict] = None) -> None: + super().__init__(input_keys, output_keys, args) + + def __call__(self, data_dict: dict) -> dict: + r"""Performs center crop. + + Args: + data_dict (dict): Input data dict + Returns: + data_dict (dict): Output dict where images are center cropped. + We also save the cropping parameters in the aug_params dict + so that it will be used by other transforms. + """ + assert ( + (self.args is not None) and ("img_w" in self.args) and ("img_h" in self.args) + ), "Please specify size in args" + + img_w, img_h = self.args["img_w"], self.args["img_h"] + + orig_w, orig_h = obtain_image_size(data_dict, self.input_keys) + for key in self.input_keys: + data_dict[key] = transforms_F.center_crop(data_dict[key], [img_h, img_w]) + + # We also add the aug params we use. This will be useful for other transforms + crop_x0 = (orig_w - img_w) // 2 + crop_y0 = (orig_h - img_h) // 2 + cropping_params = { + "resize_w": orig_w, + "resize_h": orig_h, + "crop_x0": crop_x0, + "crop_y0": crop_y0, + "crop_w": img_w, + "crop_h": img_h, + } + + if "aug_params" not in data_dict: + data_dict["aug_params"] = dict() + + data_dict["aug_params"]["cropping"] = cropping_params + data_dict["padding_mask"] = torch.zeros((1, cropping_params["crop_h"], cropping_params["crop_w"])) + return data_dict + + +class Normalize(Augmentor): + def __init__(self, input_keys: list, output_keys: Optional[list] = None, args: Optional[dict] = None) -> None: + super().__init__(input_keys, output_keys, args) + + def __call__(self, data_dict: dict) -> dict: + r"""Performs data normalization. + + Args: + data_dict (dict): Input data dict + Returns: + data_dict (dict): Output dict where images are center cropped. + """ + assert self.args is not None, "Please specify args" + + mean = self.args["mean"] + std = self.args["std"] + + for key in self.input_keys: + if isinstance(data_dict[key], torch.Tensor): + data_dict[key] = data_dict[key].to(dtype=torch.get_default_dtype()).div(255) + else: + data_dict[key] = transforms_F.to_tensor(data_dict[key]) # division by 255 is applied in to_tensor() + + data_dict[key] = transforms_F.normalize(tensor=data_dict[key], mean=mean, std=std) + return data_dict diff --git a/cosmos_transfer1/diffusion/datasets/video_dataset.py b/cosmos_transfer1/diffusion/datasets/video_dataset.py new file mode 100644 index 00000000..20fd8c4c --- /dev/null +++ b/cosmos_transfer1/diffusion/datasets/video_dataset.py @@ -0,0 +1,353 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import warnings +import traceback +from typing import List, Tuple, Dict, Optional +from concurrent.futures import ThreadPoolExecutor, as_completed +from dataclasses import dataclass, field +from torchvision import transforms as T + +import numpy as np +import torch +from torch.utils.data import Dataset +from tqdm import tqdm +from decord import VideoReader, cpu +import pickle + +from cosmos_transfer1.diffusion.datasets.dataset_utils import ( + ResizeSmallestSideAspectPreserving, + CenterCrop, + Normalize, +) +from cosmos_transfer1.diffusion.training.datasets.dataset_utils import ( + ToTensorVideo, Resize_Preprocess +) +from cosmos_transfer1.diffusion.datasets.augmentor_provider import AUGMENTOR_OPTIONS +from cosmos_transfer1.diffusion.datasets.augmentors.control_input import ( + VIDEO_RES_SIZE_INFO, + AddControlInputComb, + AddControlInput, +) +# mappings between control types and corresponding sub-folders names in the data folder +CTRL_AUG_KEYS = { + "depth": "depth", + "seg": "seg", + "human_kpts": "human_kpts", +} + +# Map control types to their folder names and whether they need pre-stored data +CTRL_TYPE_INFO = { + "human_kpts": {"folder": "human_annotation", "needs_data": True}, + "depth": {"folder": "depth", "needs_data": True}, + "seg": {"folder": "seg", "needs_data": True}, + "canny": {"folder": None, "needs_data": False}, # Computed on-the-fly + "blur": {"folder": None, "needs_data": False}, # Computed on-the-fly + "upscale": {"folder": None, "needs_data": False} # Computed on-the-fly +} + + +@dataclass +class VideoDatasetWithCtrlConfig: # TODO (qianlim) not needed? + """Configuration for VideoDatasetWithCtrlAnnotations. + + Args: + dataset_name (str): Name of the dataset (e.g. "hdvila:control_input_human_kpts") + resolution (str): Data resolution ("256", "720", "1080") + num_video_frames (int): Number of frames to sample + video_decoder_name (str): Name of the video decoder + is_train (bool): Whether in training mode + use_fps_control (bool): Whether to use FPS control + min_fps_thres (int): Minimum FPS threshold when use_fps_control is True + max_fps_thres (int): Maximum FPS threshold when use_fps_control is True + dataset_resolution (str, optional): Minimum resolution to use in dataset + chunk_size (int, optional): Size of video chunks + rename_keys_src (list): Source keys to rename + rename_keys_target (list): Target keys to rename to + blur_config (dict, optional): Configuration for blur control + """ + dataset_name: str # e.g. "hdvila:control_input_human_kpts" + resolution: str + num_video_frames: int + is_train: bool + video_decoder_name: str = "video_decoder_w_controlled_fps" + use_fps_control: bool = False + min_fps_thres: int = 4 + max_fps_thres: int = 24 + dataset_resolution: Optional[str] = None + chunk_size: Optional[int] = None + rename_keys_src: List[str] = field(default_factory=list) + rename_keys_target: List[str] = field(default_factory=list) + blur_config: Optional[dict] = None + + +class VideoDatasetWithCtrlAnnotations(Dataset): + def __init__( + self, + dataset_dir, + sequence_interval, + num_frames, + video_size, + resolution, + start_frame_interval=1, + ctrl_types=None, + augmentor_name="video_basic_augmentor", + is_train=True + ): + """Dataset class for loading image-text-to-video generation data with control inputs. + + Args: + dataset_dir (str): Base path to the dataset directory + sequence_interval (int): Interval between sampled frames in a sequence + num_frames (int): Number of frames to load per sequence + video_size (list): Target size [H,W] for video frames + start_frame_interval (int): Interval for starting frames + ctrl_types (list): List of control types to load (e.g. ["human_kpts", "depth"]) + augmentor_name (str): Name of the augmentor to use + is_train (bool): Whether this is for training + """ + super().__init__() + self.dataset_dir = dataset_dir + self.start_frame_interval = start_frame_interval + self.sequence_interval = sequence_interval + self.sequence_length = num_frames + self.is_train = is_train + self.resolution = resolution + + assert resolution in VIDEO_RES_SIZE_INFO.keys(), "The provided resolution cannot be found in VIDEO_RES_SIZE_INFO." + + # Control input setup with file formats + self.ctrl_types = ctrl_types or [] + self.ctrl_config = { + "human_kpts": {"folder": "human_kpts", "format": "pkl"}, + "depth": {"folder": "depth", "format": "mp4"}, + "segmentation": {"folder": "seg", "format": "pkl"} + } + + # Set up directories + video_dir = os.path.join(self.dataset_dir, "videos") + self.video_paths = [os.path.join(video_dir, f) for f in os.listdir(video_dir) if f.endswith(".mp4")] + self.t5_dir = os.path.join(self.dataset_dir, "t5_xxl") + print(f"{len(self.video_paths)} videos in total") + + # Initialize samples + self.samples = self._init_samples(self.video_paths) + self.samples = sorted(self.samples, key=lambda x: (x["video_path"], x["frame_ids"][0])) + print(f"{len(self.samples)} samples in total") + + # Set up preprocessing and augmentation + self.wrong_number = 0 + self.preprocess = T.Compose([ToTensorVideo(), Resize_Preprocess(tuple(video_size))]) + + if self.ctrl_types: + self.augmentor = AUGMENTOR_OPTIONS[augmentor_name]( + resolution=resolution, + text_transform_input_keys="", + append_fps_frames=False + ) + else: + self.augmentor = None + + def _init_samples(self, video_paths): + samples = [] + with ThreadPoolExecutor(32) as executor: + future_to_video_path = { + executor.submit(self._load_and_process_video_path, video_path): video_path + for video_path in video_paths + } + for future in tqdm(as_completed(future_to_video_path), total=len(video_paths)): + samples.extend(future.result()) + return samples + + def _load_and_process_video_path(self, video_path): + # TODO (qianlim) add support for loading a chunck of N frames from loaded video and return the video and the frame ids + vr = VideoReader(video_path, ctx=cpu(0), num_threads=2) + n_frames = len(vr) + + # Check if all required control files exist + ctrl_files_exist = True + video_name = os.path.basename(video_path).replace(".mp4", "") + for ctrl_type in self.ctrl_types: + if ctrl_type not in self.ctrl_config: + continue + ctrl_info = self.ctrl_config[ctrl_type] + ctrl_path = os.path.join( + self.dataset_dir, + ctrl_info["folder"], + f"{video_name}.{ctrl_info['format']}" + ) + if not os.path.exists(ctrl_path): + ctrl_files_exist = False + warnings.warn(f"Missing control file: {ctrl_path}") + break + + samples = [] + if not ctrl_files_exist: + return samples + + for frame_i in range(0, n_frames, self.start_frame_interval): + sample = dict() + sample["video_path"] = video_path + sample["t5_embedding_path"] = os.path.join( + self.t5_dir, + os.path.basename(video_path).replace(".mp4", ".pickle"), + ) + # Add control paths with their formats + sample["ctrl_paths"] = {} + for ctrl_type in self.ctrl_types: + if ctrl_type in self.ctrl_config: + ctrl_info = self.ctrl_config[ctrl_type] + sample["ctrl_paths"][ctrl_info["folder"]] = { + "path": os.path.join( + self.dataset_dir, + ctrl_info["folder"], + f"{video_name}.{ctrl_info['format']}" + ), + "format": ctrl_info["format"] + } + + sample["frame_ids"] = [] + curr_frame_i = frame_i + while True: + if curr_frame_i > (n_frames - 1): + break + sample["frame_ids"].append(curr_frame_i) + if len(sample["frame_ids"]) == self.sequence_length: + break + curr_frame_i += self.sequence_interval + if len(sample["frame_ids"]) == self.sequence_length: + samples.append(sample) + return samples + + def _load_control_data(self, sample): + """Load control data for the video clip.""" + data_dict = {} + frame_ids = sample["frame_ids"] + + for ctrl_folder, ctrl_info in sample["ctrl_paths"].items(): + try: + if ctrl_info["format"] == "pkl": + # Load pickle files (for human_kpts and segmentation) + with open(ctrl_info["path"], 'rb') as f: + ctrl_data = pickle.load(f) + data_dict[ctrl_folder] = ctrl_data + + elif ctrl_info["format"] == "mp4": + # Load video files (for depth) + vr = VideoReader(ctrl_info["path"], ctx=cpu(0)) + # Ensure the depth video has the same number of frames + assert len(vr) >= frame_ids[-1] + 1, \ + f"Depth video {ctrl_info['path']} has fewer frames than main video" + + # Load the corresponding frames + depth_frames = vr.get_batch(frame_ids).asnumpy() + depth_frames = torch.from_numpy(depth_frames).permute(0, 3, 1, 2) # [T,C,H,W] + + data_dict[ctrl_folder] = { + "video": depth_frames, + "frame_start": frame_ids[0], + "frame_end": frame_ids[-1], + "chunk_index": 0 # Required by some augmentors + } + + except Exception as e: + warnings.warn(f"Failed to load control data from {ctrl_info['path']}: {str(e)}") + return None + + return data_dict + + def _load_video(self, video_path, frame_ids): + vr = VideoReader(video_path, ctx=cpu(0), num_threads=2) + assert (np.array(frame_ids) < len(vr)).all() + assert (np.array(frame_ids) >= 0).all() + vr.seek(0) + frame_data = vr.get_batch(frame_ids).asnumpy() + try: + fps = vr.get_avg_fps() + except Exception: # failed to read FPS + fps = 24 + return frame_data, fps + + def _get_frames(self, video_path, frame_ids): + frames, fps = self._load_video(video_path, frame_ids) + frames = frames.astype(np.uint8) + frames = torch.from_numpy(frames).permute(0, 3, 1, 2) # (l, c, h, w) + frames = self.preprocess(frames) + frames = torch.clamp(frames * 255.0, 0, 255).to(torch.uint8) + return frames, fps + + def __getitem__(self, index): + try: + sample = self.samples[index] + video_path = sample["video_path"] + frame_ids = sample["frame_ids"] + + data = dict() + + # Load video frames + video, fps = self._get_frames(video_path, frame_ids) + video = video.permute(1, 0, 2, 3) # Rearrange from [T, C, H, W] to [C, T, H, W] + + # Basic data + data["video"] = video + data["video_name"] = { + "video_path": video_path, + "t5_embedding_path": sample["t5_embedding_path"], + "start_frame_id": str(frame_ids[0]), + } + + # Load T5 embeddings + with open(sample["t5_embedding_path"], "rb") as f: + t5_embedding = pickle.load(f)[0] + data["t5_text_embeddings"] = torch.from_numpy(t5_embedding).cuda() + data["t5_text_mask"] = torch.ones(512, dtype=torch.int64).cuda() + + # Add metadata + data["fps"] = fps + data["frame_start"] = frame_ids[0] + data["frame_end"] = frame_ids[-1] + data["image_size"] = torch.tensor([704, 1280, 704, 1280]).cuda() + data["num_frames"] = self.sequence_length + data["padding_mask"] = torch.zeros(1, 704, 1280).cuda() + + if self.ctrl_types: + ctrl_data = self._load_control_data(sample) + if ctrl_data is None: # Control data loading failed + return self[np.random.randint(len(self.samples))] + data.update(ctrl_data) + + # Apply augmentations including control input processing + for aug_name, aug_fn in self.augmentor.items(): + data = aug_fn(data) + + return data + + except Exception: + warnings.warn( + f"Invalid data encountered: {self.samples[index]['video_path']}. Skipped " + f"(by randomly sampling another sample in the same dataset)." + ) + warnings.warn("FULL TRACEBACK:") + warnings.warn(traceback.format_exc()) + self.wrong_number += 1 + print(self.wrong_number) + return self[np.random.randint(len(self.samples))] + + def __len__(self): + return len(self.samples) + + def __str__(self): + return f"{len(self.video_paths)} samples from {self.dataset_dir}" \ No newline at end of file diff --git a/cosmos_transfer1/diffusion/training/train.py b/cosmos_transfer1/diffusion/training/train.py index 40b02af2..ebf71be7 100644 --- a/cosmos_transfer1/diffusion/training/train.py +++ b/cosmos_transfer1/diffusion/training/train.py @@ -16,7 +16,7 @@ import argparse import importlib import os - +import time import torch.distributed as dist from loguru import logger as logging from omegaconf import OmegaConf @@ -59,6 +59,14 @@ def destroy_distributed(): def launch(config: Config, args: argparse.Namespace) -> None: # Check that the config is valid config.validate() + if config.trainer.timestamp_seed: # TODO (qianlim): check if this is set in the config yaml + # Get the current time in microseconds + current_time = int(time.time() * 1e6) + # Combine the current time with worker_id to ensure different seeds across workers + seed = current_time % (2**32) + config.trainer.seed = seed + log.critical(f"Changed Random Seed based on timestamp. {config.trainer.seed}") + # Freeze the config so developers don't change it during training. config.freeze() # type: ignore trainer = config.trainer.type(config) diff --git a/examples/post-training_cosmos_transfer_7b_edge.md b/examples/post-training_cosmos_transfer_7b_edge.md index 91564e63..63bb89ca 100644 --- a/examples/post-training_cosmos_transfer_7b_edge.md +++ b/examples/post-training_cosmos_transfer_7b_edge.md @@ -145,12 +145,14 @@ Training a VisControl or EdgeControl model is self-supervised: we apply blurs an #### 3. Post-train the Model -Run the following command to execute an example post-training job with the above data. +Run the following command to execute an example post-training job with the above data. The command below will output the detailed experiment config file at +`checkpoints/cosmos_transfer1_posttrain/CTRL_7Bv1_lvg/CTRL_7Bv1pt3_lvg_tp_121frames_control_input_edge_block3/config.yaml`. ```bash export OUTPUT_ROOT=checkpoints # default value -torchrun --nproc_per_node=8 -m cosmos_transfer1.diffusion.training.train --config=cosmos_transfer1/diffusion/training/config/config.py --experiment=CTRL_7Bv1pt3_lvg_tp_121frames_control_input_edge_block3 +torchrun --nproc_per_node=1 -m cosmos_transfer1.diffusion.training.train --dryrun --config=cosmos_transfer1/diffusion/config/config_train.py -- experiment=CTRL_7Bv1pt3_lvg_tp_121frames_control_input_edge_block3 ``` +Removing the `--dryrun` will start a real training job. checkpoints/cosmos_transfer1/CTRL_7Bv1_lvg/CTRL_7Bv1pt3_lvg_tp_121frames_control_input_edge_block3/config.yaml This command will use ``cosmos_transfer1/diffusion/config/training/experiment/ctrl_7b_tp_121frames.py` to register experiments for all `hint_keys` (control modalities). diff --git a/requirements.txt b/requirements.txt index d1f781c0..7329f303 100644 --- a/requirements.txt +++ b/requirements.txt @@ -17,6 +17,8 @@ attrs==25.1.0 av==14.2.0 better-profanity==0.7.0 +decord==0.6.0 +diffusers==0.32.2 einops==0.7.0 einx==0.1.3 huggingface-hub==0.29.2 @@ -28,16 +30,25 @@ mediapy==1.2.2 megatron-core==0.10.0 natsort==8.4.0 nltk==3.9.1 +numpy==1.26.4 +nvidia-ml-py==12.535.133 +omegaconf==2.3.0 opencv-contrib-python==4.10.0.84 +pandas==2.2.3 peft==0.14.0 pillow==10.4.0 +pynvml==12.0.0 +pyyaml==6.0.2 pycocotools retinaface-py==0.0.2 rtmlib==0.0.13 sam2==1.1.0 +safetensors==0.5.3 +scikit-image==0.25.2 sentencepiece==0.2.0 termcolor==2.5.0 torch==2.6.0 torchvision==0.21.0 transformers==4.49.0 +tqdm==4.66.5 vllm==0.8.0 From ee9a31b14b59885269e05affc8944cf6980c3abb Mon Sep 17 00:00:00 2001 From: Qianli Ma Date: Mon, 14 Apr 2025 19:22:42 -0700 Subject: [PATCH 04/10] feat: add example data class, add misc improvements to data loading and config, add script to convert ckpt to tp --- .../diffusion/config/base/data.py | 54 ++++ cosmos_transfer1/diffusion/config/registry.py | 3 + .../experiment/ctrl_7b_tp_121frames.py | 55 +--- .../config/training/registry_extra.py | 6 +- .../diffusion/config/transfer/registry.py | 9 +- .../diffusion/datasets/augmentor_provider.py | 24 +- .../datasets/augmentors/control_input.py | 4 +- .../datasets/augmentors/merge_datadict.py | 4 +- .../augmentors/text_transforms_for_video.py | 6 +- .../diffusion/datasets/dataset_utils.py | 2 +- ...dataset.py => example_transfer_dataset.py} | 260 ++++++++---------- .../training/datasets/dataset_video.py | 206 -------------- scripts/convert_ckpt_fsdp_to_tp.py | 129 +++++++++ 13 files changed, 331 insertions(+), 431 deletions(-) create mode 100644 cosmos_transfer1/diffusion/config/base/data.py rename cosmos_transfer1/diffusion/datasets/{video_dataset.py => example_transfer_dataset.py} (55%) delete mode 100644 cosmos_transfer1/diffusion/training/datasets/dataset_video.py create mode 100644 scripts/convert_ckpt_fsdp_to_tp.py diff --git a/cosmos_transfer1/diffusion/config/base/data.py b/cosmos_transfer1/diffusion/config/base/data.py new file mode 100644 index 00000000..283434a4 --- /dev/null +++ b/cosmos_transfer1/diffusion/config/base/data.py @@ -0,0 +1,54 @@ +from megatron.core import parallel_state +from torch.utils.data import DataLoader, DistributedSampler + +from cosmos_transfer1.diffusion.config.transfer.conditioner import CTRL_HINT_KEYS +from cosmos_transfer1.diffusion.datasets.example_transfer_dataset import ( + ExampleTransferDataset, +) +from cosmos_transfer1.utils.lazy_config import LazyCall as L + + +def get_sampler(dataset): + return DistributedSampler( + dataset, + num_replicas=parallel_state.get_data_parallel_world_size(), + rank=parallel_state.get_data_parallel_rank(), + shuffle=True, + seed=0, + ) + + +def get_example_transfer_dataset(hint_key, is_train=True): + dataset = L(ExampleTransferDataset)( + dataset_dir="datasets/example_transfer_training_data", + chunk_size=256, + num_frames=121, + resolution="720", + hint_key=hint_key, + is_train=is_train, + ) + + return L(DataLoader)( + dataset=dataset, + sampler=L(get_sampler)(dataset=dataset), + batch_size=1, + drop_last=True, + ) + + +# NOTE 1: For customized post train: add your dataloader registration here. +# NOTE 2: The loop below simply registers a dataset for all hint_keys in CTRL_HINT_KEYS. The actual data might not exist. +def register_data_ctrlnet(cs): + for hint_key in CTRL_HINT_KEYS: + cs.store( + group="data_train", + package="dataloader_train", + name=f"example_transfer_train_data_{hint_key}", + node=get_example_transfer_dataset(hint_key=hint_key, is_train=True), + ) + cs.store( + group="data_val", + package="dataloader_val", + name=f"example_transfer_val_data_{hint_key}", + node=get_example_transfer_dataset(hint_key=hint_key, is_train=False), + ) diff --git a/cosmos_transfer1/diffusion/config/registry.py b/cosmos_transfer1/diffusion/config/registry.py index 7b264c30..6761c805 100644 --- a/cosmos_transfer1/diffusion/config/registry.py +++ b/cosmos_transfer1/diffusion/config/registry.py @@ -66,6 +66,9 @@ def register_tokenizer(cs): def register_configs(): + ''' + base model related registry + ''' cs = ConfigStore.instance() register_net(cs) diff --git a/cosmos_transfer1/diffusion/config/training/experiment/ctrl_7b_tp_121frames.py b/cosmos_transfer1/diffusion/config/training/experiment/ctrl_7b_tp_121frames.py index a6a65e70..e2fc161c 100644 --- a/cosmos_transfer1/diffusion/config/training/experiment/ctrl_7b_tp_121frames.py +++ b/cosmos_transfer1/diffusion/config/training/experiment/ctrl_7b_tp_121frames.py @@ -27,32 +27,24 @@ from cosmos_transfer1.utils.lazy_config import LazyCall as L from cosmos_transfer1.utils.lazy_config import LazyDict -from cosmos_transfer1.diffusion.config.transfer.blurs import random_blur_config from cosmos_transfer1.diffusion.config.transfer.conditioner import CTRL_HINT_KEYS_COMB from cosmos_transfer1.diffusion.training.models.model_ctrl import VideoDiffusionModelWithCtrl # this one has training support from cosmos_transfer1.diffusion.networks.general_dit_video_conditioned import VideoExtendGeneralDIT + cs = ConfigStore.instance() num_frames = 121 num_blocks = 28 num_control_blocks = 3 -# TODO (qianlim) add data config -def get_data_train_name(hint_key: str) -> str: - pass -def get_data_val_name(hint_key: str) -> str: - pass def make_ctrlnet_config_7b_training( hint_key: str = "control_input_canny", num_control_blocks: int = 3, ) -> LazyDict: - data_train = get_data_train_name(hint_key) - data_val = get_data_val_name(hint_key) - # Create the complete configuration in one step config = LazyDict( dict( @@ -67,8 +59,9 @@ def make_ctrlnet_config_7b_training( {"override /checkpoint": "local"}, {"override /ckpt_klass": "fast_tp"}, # - {"override /data_train": data_train}, - {"override /data_val": data_val}, + # data: register your own data at cosmos_transfer1/diffusion/config/base/data.py + {"override /data_train": f"example_transfer_train_data_{hint_key}"}, + {"override /data_val": f"example_transfer_val_data_{hint_key}"}, "_self_", ], job=dict( @@ -83,7 +76,10 @@ def make_ctrlnet_config_7b_training( eps=1e-10, ), checkpoint=dict( - load_path="checkpoints/nvidia/Cosmos-Transfer1-7B/vis_control.pt", # modify as needed. Here we assume post-train our pre-trained VisControl model. + # Modify load_path as needed if you do post-training (fine-tuning). + # If training from scratch, leave it empty. + # Here we assume post-train our pre-trained VisControl model. + load_path="checkpoints/nvidia/Cosmos-Transfer1-7B/vis_control.pt", broadcast_via_filesystem=True, save_iter=1000, load_training_state=False, @@ -111,8 +107,8 @@ def make_ctrlnet_config_7b_training( 160, ], base_load_from=dict( - load_path="checkpoints/nvidia/Cosmos-Transfer1-7B/base_model.pt", # modify as needed. This is the base model (that's frozen during training). - ), + load_path="checkpoints/nvidia/Cosmos-Transfer1-7B/base_model.pt", + ), # modify as needed. This is the base model ckpt (that's frozen during training). finetune_base_model=False, hint_mask=[True] * len(CTRL_HINT_KEYS_COMB[hint_key]), hint_dropout_rate=0.3, @@ -160,37 +156,6 @@ def make_ctrlnet_config_7b_training( f_max=[1.0], f_min=[1.0], ), - dataloader_val=dict( - dataset=dict( - resolution="720", - num_video_frames=num_frames, - ), - ), - dataloader_train=dict( - dataloaders=dict( - image_data=dict( - dataloader=dict( - batch_size=1, - dataset=dict( - resolution="720", - blur_config=random_blur_config, - ), - ), - ratio=0, # only use video data for training. - ), - video_data=dict( - dataloader=dict( - batch_size=1, - dataset=dict( - resolution="720", - num_video_frames=num_frames, - blur_config=random_blur_config, - ), - ), - ratio=1, - ), - ), - ), ) ) return config diff --git a/cosmos_transfer1/diffusion/config/training/registry_extra.py b/cosmos_transfer1/diffusion/config/training/registry_extra.py index 3769a299..1e83fb57 100644 --- a/cosmos_transfer1/diffusion/config/training/registry_extra.py +++ b/cosmos_transfer1/diffusion/config/training/registry_extra.py @@ -25,10 +25,8 @@ from cosmos_transfer1.diffusion.config.transfer.registry import register_experiment_ctrlnet +from cosmos_transfer1.diffusion.config.base.data import register_data_ctrlnet -# TODO (qianlim) add config / tutorial for mock data -def register_data_ctrlnet(cs): - pass def register_configs(): cs = ConfigStore.instance() @@ -40,5 +38,5 @@ def register_configs(): base_training_registry.register_configs() # following will register data, experiment, callbacks - # register_data_ctrlnet(cs) # Coming soon + register_data_ctrlnet(cs) register_experiment_ctrlnet(cs) diff --git a/cosmos_transfer1/diffusion/config/transfer/registry.py b/cosmos_transfer1/diffusion/config/transfer/registry.py index c9736f52..bf265c85 100644 --- a/cosmos_transfer1/diffusion/config/transfer/registry.py +++ b/cosmos_transfer1/diffusion/config/transfer/registry.py @@ -25,9 +25,12 @@ def register_experiment_ctrlnet(cs): - # TODO: maybe we should change the 'name' here; it's the dit-encoder for net_ctrl - # but current naming is the same as for the main 'net' group (which corresponds to the full DiT) - # that's defined in cosmos_transfer1/diffusion/config/registry.py. Isn't an error but could be confusing. + ''' + transfer model related registry: controlnet architecture, hint keys, etc. + ''' + # TODO: maybe we should change the registered 'name' (faditv2_7b) here; it's the dit-encoder for net_ctrl + # but current naming is the same as the full DiT in the main 'net' group that's defined + # in cosmos_transfer1/diffusion/config/registry.py. Isn't an error but could be confusing. cs.store(group="net_ctrl", package="model.net_ctrl", name="faditv2_7b", node=FADITV2EncoderConfig) cs.store(group="conditioner", package="model.conditioner", name="ctrlnet", node=BaseVideoConditionerWithCtrlConfig) diff --git a/cosmos_transfer1/diffusion/datasets/augmentor_provider.py b/cosmos_transfer1/diffusion/datasets/augmentor_provider.py index f1953b7f..7be95b23 100644 --- a/cosmos_transfer1/diffusion/datasets/augmentor_provider.py +++ b/cosmos_transfer1/diffusion/datasets/augmentor_provider.py @@ -24,16 +24,14 @@ ResizeLargestSideAspectPreserving, ReflectionPadding, ) -from cosmos_transfer1.diffusion.datasets.augmentors.text_transforms_for_video import ( - TextTransformForVideo, -) from cosmos_transfer1.diffusion.config.transfer.conditioner import ( CTRL_HINT_KEYS, CTRL_HINT_KEYS_COMB, ) -from cosmos_transfer1.diffusion.datasets.video_dataset import CTRL_AUG_KEYS +from cosmos_transfer1.diffusion.datasets.example_transfer_dataset import CTRL_AUG_KEYS from cosmos_transfer1.diffusion.config.transfer.blurs import BlurAugmentorConfig + AUGMENTOR_OPTIONS = {} @@ -48,7 +46,6 @@ def decorator(func): @augmentor_register("video_basic_augmentor") def get_video_augmentor( resolution: str, - text_transform_input_keys: str, append_fps_frames: str = False, blur_config=None, ): @@ -75,13 +72,6 @@ def get_video_augmentor( input_keys=["video"], args={"size": VIDEO_RES_SIZE_INFO[resolution]}, ), - "text_transform": L(TextTransformForVideo)( - input_keys=text_transform_input_keys, - args={ - "t5_tokens": {"num": 512, "dim": 1024}, - "is_mask_all_ones": True, - }, - ), } @@ -93,7 +83,6 @@ def get_video_augmentor( def get_video_ctrlnet_augmentor(hint_key, use_random=True): def _get_video_ctrlnet_augmentor( resolution: str, - text_transform_input_keys: str, blur_config: BlurAugmentorConfig, ): if hint_key == "control_input_human_kpts": @@ -143,7 +132,9 @@ def _get_video_ctrlnet_augmentor( input_keys=input_keys, output_keys=output_keys, ), + # this addes the control input tensor to the data dict "add_control_input": add_control_input, + # this resizes both the video and the control input to the model's required input size "resize_largest_side_aspect_ratio_preserving": L( ResizeLargestSideAspectPreserving )( @@ -154,13 +145,6 @@ def _get_video_ctrlnet_augmentor( input_keys=["video", hint_key], args={"size": VIDEO_RES_SIZE_INFO[resolution]}, ), - "text_transform": L(TextTransformForVideo)( - input_keys=text_transform_input_keys, - args={ - "t5_tokens": {"num": 512, "dim": 1024}, - "is_mask_all_ones": True, - }, - ), } return augmentation diff --git a/cosmos_transfer1/diffusion/datasets/augmentors/control_input.py b/cosmos_transfer1/diffusion/datasets/augmentors/control_input.py index 7c5c1d4a..2dbdee89 100644 --- a/cosmos_transfer1/diffusion/datasets/augmentors/control_input.py +++ b/cosmos_transfer1/diffusion/datasets/augmentors/control_input.py @@ -66,8 +66,6 @@ "16,9": (1920, 1056), "9,16": (1056, 1920), }, - # 1024; the video format does not support it, but here we match it with image resolution - "1024": {"1,1": (1024, 1024), "4,3": (1280, 1024), "3,4": (1024, 1280), "16,9": (1280, 768), "9,16": (768, 1280)}, "720": {"1,1": (960, 960), "4,3": (960, 704), "3,4": (704, 960), "16,9": (1280, 704), "9,16": (704, 1280)}, "512": {"1,1": (512, 512), "4,3": (640, 512), "3,4": (512, 640), "16,9": (640, 384), "9,16": (384, 640)}, "480": {"1,1": (480, 480), "4,3": (640, 480), "3,4": (480, 640), "16,9": (768, 432), "9,16": (432, 768)}, @@ -704,7 +702,7 @@ def __call__(self, data_dict: dict) -> dict: data_dict[self.output_keys[0]] = torch.from_numpy(np.zeros((3, h, w)).astype(np.uint8)) return data_dict - assert data_dict["chunk_index"] == data_dict["depth"]["chunk_index"] + # assert data_dict["chunk_index"] == data_dict["depth"]["chunk_index"] key_out = self.output_keys[0] depth = data_dict["depth"]["video"] data_dict[key_out] = depth diff --git a/cosmos_transfer1/diffusion/datasets/augmentors/merge_datadict.py b/cosmos_transfer1/diffusion/datasets/augmentors/merge_datadict.py index 3b06db99..bfc61897 100644 --- a/cosmos_transfer1/diffusion/datasets/augmentors/merge_datadict.py +++ b/cosmos_transfer1/diffusion/datasets/augmentors/merge_datadict.py @@ -43,8 +43,8 @@ def __call__(self, data_dict: dict) -> dict: key_dict = data_dict.pop(key) if key == "depth" and "depth" in self.output_keys: data_dict["depth"] = key_dict - if key == "human_annotation" and "human_annotation" in self.output_keys: - data_dict["human_annotation"] = key_dict + if key == "human_kpts" and "human_kpts" in self.output_keys: + data_dict["human_kpts"] = key_dict elif key == "segmentation" and "segmentation" in self.output_keys: data_dict["segmentation"] = key_dict for sub_key in key_dict: diff --git a/cosmos_transfer1/diffusion/datasets/augmentors/text_transforms_for_video.py b/cosmos_transfer1/diffusion/datasets/augmentors/text_transforms_for_video.py index 1ad5c2cb..841f956f 100644 --- a/cosmos_transfer1/diffusion/datasets/augmentors/text_transforms_for_video.py +++ b/cosmos_transfer1/diffusion/datasets/augmentors/text_transforms_for_video.py @@ -92,11 +92,8 @@ def __call__(self, data_dict: dict) -> dict: try: if "vila_caption" in selected_caption_window: caption_type = "vila_caption" - elif "qwen_caption" in selected_caption_window: - caption_type = "qwen_caption" else: caption_type = random.choices(["long_caption", "short_caption"], weights=[0.95, 0.05], k=1)[0] - # TODO(hanzim): make probabilities configurable when we need it data_dict["ai_caption"] = selected_caption_window[caption_type] except Exception as e: log.warning( @@ -105,14 +102,13 @@ def __call__(self, data_dict: dict) -> dict: ) return None - # TODO(hanzim): temp fix for samples with gt_caption = None if data_dict["ai_caption"] is None: data_dict["ai_caption"] = "" del data_dict[input_keys_by_source["ai_caption"]] ai_caption_embedding_data = data_dict[input_keys_by_source["ai_caption_embedding"]] try: - if caption_type in ["vila_caption", "qwen_caption"]: + if caption_type in ["vila_caption"]: t5_embedding = ai_caption_embedding_data[data_dict["chunk_index"]] else: t5_embedding = ai_caption_embedding_data[data_dict["chunk_index"]][ diff --git a/cosmos_transfer1/diffusion/datasets/dataset_utils.py b/cosmos_transfer1/diffusion/datasets/dataset_utils.py index 0969aaa1..5bd2c1d1 100644 --- a/cosmos_transfer1/diffusion/datasets/dataset_utils.py +++ b/cosmos_transfer1/diffusion/datasets/dataset_utils.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Optional +from typing import Any, Optional, Union import torch import torchvision.transforms.functional as transforms_F diff --git a/cosmos_transfer1/diffusion/datasets/video_dataset.py b/cosmos_transfer1/diffusion/datasets/example_transfer_dataset.py similarity index 55% rename from cosmos_transfer1/diffusion/datasets/video_dataset.py rename to cosmos_transfer1/diffusion/datasets/example_transfer_dataset.py index 20fd8c4c..6ee69eb6 100644 --- a/cosmos_transfer1/diffusion/datasets/video_dataset.py +++ b/cosmos_transfer1/diffusion/datasets/example_transfer_dataset.py @@ -13,13 +13,15 @@ # See the License for the specific language governing permissions and # limitations under the License. +""" +Run this command to interactively debug: +PYTHONPATH=. python cosmos_transfer1/diffusion/datasets/example_transfer_dataset.py +""" + import os import warnings import traceback -from typing import List, Tuple, Dict, Optional from concurrent.futures import ThreadPoolExecutor, as_completed -from dataclasses import dataclass, field -from torchvision import transforms as T import numpy as np import torch @@ -28,114 +30,64 @@ from decord import VideoReader, cpu import pickle -from cosmos_transfer1.diffusion.datasets.dataset_utils import ( - ResizeSmallestSideAspectPreserving, - CenterCrop, - Normalize, -) -from cosmos_transfer1.diffusion.training.datasets.dataset_utils import ( - ToTensorVideo, Resize_Preprocess -) from cosmos_transfer1.diffusion.datasets.augmentor_provider import AUGMENTOR_OPTIONS -from cosmos_transfer1.diffusion.datasets.augmentors.control_input import ( - VIDEO_RES_SIZE_INFO, - AddControlInputComb, - AddControlInput, -) -# mappings between control types and corresponding sub-folders names in the data folder +from cosmos_transfer1.diffusion.datasets.augmentors.control_input import VIDEO_RES_SIZE_INFO + + CTRL_AUG_KEYS = { "depth": "depth", - "seg": "seg", + "seg": "segmentation", "human_kpts": "human_kpts", } -# Map control types to their folder names and whether they need pre-stored data +# mappings between control types and corresponding sub-folders names in the data folder CTRL_TYPE_INFO = { - "human_kpts": {"folder": "human_annotation", "needs_data": True}, - "depth": {"folder": "depth", "needs_data": True}, - "seg": {"folder": "seg", "needs_data": True}, - "canny": {"folder": None, "needs_data": False}, # Computed on-the-fly - "blur": {"folder": None, "needs_data": False}, # Computed on-the-fly - "upscale": {"folder": None, "needs_data": False} # Computed on-the-fly + "human_kpts": {"folder": "human_kpts", "format": "pickle", "data_dict_key": "human_kpts"}, + "depth": {"folder": "depth", "format": "mp4", "data_dict_key": "depth"}, + "seg": {"folder": "seg", "format": "pickle", "data_dict_key": "segmentation"}, + "edge": {"folder": None}, # Canny edge, computed on-the-fly + "vis": {"folder": None}, # Blur, computed on-the-fly + "upscale": {"folder": None} # Computed on-the-fly } -@dataclass -class VideoDatasetWithCtrlConfig: # TODO (qianlim) not needed? - """Configuration for VideoDatasetWithCtrlAnnotations. - - Args: - dataset_name (str): Name of the dataset (e.g. "hdvila:control_input_human_kpts") - resolution (str): Data resolution ("256", "720", "1080") - num_video_frames (int): Number of frames to sample - video_decoder_name (str): Name of the video decoder - is_train (bool): Whether in training mode - use_fps_control (bool): Whether to use FPS control - min_fps_thres (int): Minimum FPS threshold when use_fps_control is True - max_fps_thres (int): Maximum FPS threshold when use_fps_control is True - dataset_resolution (str, optional): Minimum resolution to use in dataset - chunk_size (int, optional): Size of video chunks - rename_keys_src (list): Source keys to rename - rename_keys_target (list): Target keys to rename to - blur_config (dict, optional): Configuration for blur control - """ - dataset_name: str # e.g. "hdvila:control_input_human_kpts" - resolution: str - num_video_frames: int - is_train: bool - video_decoder_name: str = "video_decoder_w_controlled_fps" - use_fps_control: bool = False - min_fps_thres: int = 4 - max_fps_thres: int = 24 - dataset_resolution: Optional[str] = None - chunk_size: Optional[int] = None - rename_keys_src: List[str] = field(default_factory=list) - rename_keys_target: List[str] = field(default_factory=list) - blur_config: Optional[dict] = None - - -class VideoDatasetWithCtrlAnnotations(Dataset): +class ExampleTransferDataset(Dataset): def __init__( self, dataset_dir, - sequence_interval, + chunk_size, num_frames, - video_size, resolution, start_frame_interval=1, - ctrl_types=None, - augmentor_name="video_basic_augmentor", + hint_key="control_input_vis", + # augmentor_name="video_basic_augmentor", is_train=True ): - """Dataset class for loading image-text-to-video generation data with control inputs. + """Dataset class for loading video-text-to-video generation data with control inputs. Args: dataset_dir (str): Base path to the dataset directory - sequence_interval (int): Interval between sampled frames in a sequence + chunk_size (int): Interval between sampled frames in a sequence. num_frames (int): Number of frames to load per sequence - video_size (list): Target size [H,W] for video frames + resolution (str): resolution of the target video size start_frame_interval (int): Interval for starting frames - ctrl_types (list): List of control types to load (e.g. ["human_kpts", "depth"]) - augmentor_name (str): Name of the augmentor to use + hint_key (str): The hint key for loading the correct control input data modality is_train (bool): Whether this is for training + + NOTE: in our example dataset we do not have a validation dataset. The is_train flag is kept here for customized configuration. """ super().__init__() self.dataset_dir = dataset_dir self.start_frame_interval = start_frame_interval - self.sequence_interval = sequence_interval + self.chunk_size = chunk_size self.sequence_length = num_frames self.is_train = is_train self.resolution = resolution - assert resolution in VIDEO_RES_SIZE_INFO.keys(), "The provided resolution cannot be found in VIDEO_RES_SIZE_INFO." # Control input setup with file formats - self.ctrl_types = ctrl_types or [] - self.ctrl_config = { - "human_kpts": {"folder": "human_kpts", "format": "pkl"}, - "depth": {"folder": "depth", "format": "mp4"}, - "segmentation": {"folder": "seg", "format": "pkl"} - } + self.ctrl_type = hint_key.lstrip("control_input_") + self.ctrl_data_pth_config = CTRL_TYPE_INFO[self.ctrl_type] # Set up directories video_dir = os.path.join(self.dataset_dir, "videos") @@ -150,16 +102,14 @@ def __init__( # Set up preprocessing and augmentation self.wrong_number = 0 - self.preprocess = T.Compose([ToTensorVideo(), Resize_Preprocess(tuple(video_size))]) - if self.ctrl_types: - self.augmentor = AUGMENTOR_OPTIONS[augmentor_name]( - resolution=resolution, - text_transform_input_keys="", - append_fps_frames=False - ) - else: - self.augmentor = None + augmentor_name = f"video_ctrlnet_augmentor_{hint_key}" + # The augmentor will process the 'raw' control input data to the tensor, + # add it to the data dict, and resize both the video and the control input to the model's required input size + self.augmentor = AUGMENTOR_OPTIONS[augmentor_name]( + resolution=resolution, + append_fps_frames=False + ) def _init_samples(self, video_paths): samples = [] @@ -173,26 +123,25 @@ def _init_samples(self, video_paths): return samples def _load_and_process_video_path(self, video_path): - # TODO (qianlim) add support for loading a chunck of N frames from loaded video and return the video and the frame ids vr = VideoReader(video_path, ctx=cpu(0), num_threads=2) n_frames = len(vr) # Check if all required control files exist ctrl_files_exist = True video_name = os.path.basename(video_path).replace(".mp4", "") - for ctrl_type in self.ctrl_types: - if ctrl_type not in self.ctrl_config: - continue - ctrl_info = self.ctrl_config[ctrl_type] + + # load control input file if needed + if self.ctrl_data_pth_config["folder"] is not None: ctrl_path = os.path.join( self.dataset_dir, - ctrl_info["folder"], - f"{video_name}.{ctrl_info['format']}" + self.ctrl_data_pth_config["folder"], + f"{video_name}.{self.ctrl_data_pth_config['format']}" ) if not os.path.exists(ctrl_path): ctrl_files_exist = False - warnings.warn(f"Missing control file: {ctrl_path}") - break + warnings.warn(f"Missing control input file: {ctrl_path}") + else: + ctrl_files_exist = True samples = [] if not ctrl_files_exist: @@ -205,21 +154,18 @@ def _load_and_process_video_path(self, video_path): self.t5_dir, os.path.basename(video_path).replace(".mp4", ".pickle"), ) - # Add control paths with their formats - sample["ctrl_paths"] = {} - for ctrl_type in self.ctrl_types: - if ctrl_type in self.ctrl_config: - ctrl_info = self.ctrl_config[ctrl_type] - sample["ctrl_paths"][ctrl_info["folder"]] = { - "path": os.path.join( - self.dataset_dir, - ctrl_info["folder"], - f"{video_name}.{ctrl_info['format']}" - ), - "format": ctrl_info["format"] - } + + if self.ctrl_data_pth_config["folder"] is not None: + sample["ctrl_path"] = os.path.join( + self.dataset_dir, + self.ctrl_data_pth_config["folder"], + f"{video_name}.{self.ctrl_data_pth_config['format']}" + ) + else: + sample["ctrl_path"] = None sample["frame_ids"] = [] + sample["chunk_index"] = -1 curr_frame_i = frame_i while True: if curr_frame_i > (n_frames - 1): @@ -227,8 +173,9 @@ def _load_and_process_video_path(self, video_path): sample["frame_ids"].append(curr_frame_i) if len(sample["frame_ids"]) == self.sequence_length: break - curr_frame_i += self.sequence_interval + curr_frame_i += self.chunk_size if len(sample["frame_ids"]) == self.sequence_length: + sample["chunk_index"] += 1 samples.append(sample) return samples @@ -236,36 +183,36 @@ def _load_control_data(self, sample): """Load control data for the video clip.""" data_dict = {} frame_ids = sample["frame_ids"] - - for ctrl_folder, ctrl_info in sample["ctrl_paths"].items(): - try: - if ctrl_info["format"] == "pkl": - # Load pickle files (for human_kpts and segmentation) - with open(ctrl_info["path"], 'rb') as f: - ctrl_data = pickle.load(f) - data_dict[ctrl_folder] = ctrl_data - - elif ctrl_info["format"] == "mp4": - # Load video files (for depth) - vr = VideoReader(ctrl_info["path"], ctx=cpu(0)) - # Ensure the depth video has the same number of frames - assert len(vr) >= frame_ids[-1] + 1, \ - f"Depth video {ctrl_info['path']} has fewer frames than main video" - - # Load the corresponding frames - depth_frames = vr.get_batch(frame_ids).asnumpy() - depth_frames = torch.from_numpy(depth_frames).permute(0, 3, 1, 2) # [T,C,H,W] - - data_dict[ctrl_folder] = { - "video": depth_frames, - "frame_start": frame_ids[0], - "frame_end": frame_ids[-1], - "chunk_index": 0 # Required by some augmentors - } - - except Exception as e: - warnings.warn(f"Failed to load control data from {ctrl_info['path']}: {str(e)}") - return None + ctrl_path = sample["ctrl_path"] + try: + if self.ctrl_type == "seg": + with open(ctrl_path, 'rb') as f: + ctrl_data = pickle.load(f) + # key should match line 982 at cosmos_transfer1/diffusion/datasets/augmentors/control_input.py + data_dict["segmentation"] = ctrl_data + elif self.ctrl_type == "human_kpts": + with open(ctrl_path, 'rb') as f: + ctrl_data = pickle.load(f) + data_dict["human_kpts"] = ctrl_data + elif self.ctrl_type == "depth": + vr = VideoReader(ctrl_path, ctx=cpu(0)) + # Ensure the depth video has the same number of frames + assert len(vr) >= frame_ids[-1] + 1, \ + f"Depth video {ctrl_data} has fewer frames than main video" + + # Load the corresponding frames + depth_frames = vr.get_batch(frame_ids).asnumpy() + depth_frames = torch.from_numpy(depth_frames).permute(0, 3, 1, 2) # [T,C,H,W] + data_dict["depth"] = { + "video": depth_frames, + "frame_start": frame_ids[0], + "frame_end": frame_ids[-1], + "chunk_index": sample["chunk_index"] + } + + except Exception as e: + warnings.warn(f"Failed to load control data from {ctrl_data}: {str(e)}") + return None return data_dict @@ -319,13 +266,14 @@ def __getitem__(self, index): data["fps"] = fps data["frame_start"] = frame_ids[0] data["frame_end"] = frame_ids[-1] + data["chunk_index"] = sample["chunk_index"] data["image_size"] = torch.tensor([704, 1280, 704, 1280]).cuda() data["num_frames"] = self.sequence_length data["padding_mask"] = torch.zeros(1, 704, 1280).cuda() - if self.ctrl_types: + if self.ctrl_type: ctrl_data = self._load_control_data(sample) - if ctrl_data is None: # Control data loading failed + if ctrl_data is None: # Control data loading failed, discard this sample and reload another sample return self[np.random.randint(len(self.samples))] data.update(ctrl_data) @@ -350,4 +298,32 @@ def __len__(self): return len(self.samples) def __str__(self): - return f"{len(self.video_paths)} samples from {self.dataset_dir}" \ No newline at end of file + return f"{len(self.video_paths)} samples from {self.dataset_dir}" + + +if __name__ == "__main__": + dataset = ExampleTransferDataset( + dataset_dir="assets/example_transfer_training_data/", + hint_key="control_input_seg", + chunk_size=1, + num_frames=121, + resolution="720", + # augmentor_name="video_basic_augmentor", + is_train=True + ) + + indices = [0, 13, 200, -1] + for idx in indices: + data = dataset[idx] + print( + ( + f"{idx=} " + f"{data['video'].sum()=}\n" + f"{data['video'].shape=}\n" + f"{data['depth']['video'].sum()=}\n" + f"{data['depth']['video'].shape=}\n" + f"{data['video_name']=}\n" + f"{data['t5_text_embeddings'].shape=}\n" + "---" + ) + ) diff --git a/cosmos_transfer1/diffusion/training/datasets/dataset_video.py b/cosmos_transfer1/diffusion/training/datasets/dataset_video.py deleted file mode 100644 index 6635ace6..00000000 --- a/cosmos_transfer1/diffusion/training/datasets/dataset_video.py +++ /dev/null @@ -1,206 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" -Run this command to interactively debug: -PYTHONPATH=. python cosmos_transfer1/diffusion/training/datasets/dataset_gear.py - -Adapted from: -https://github.com/bytedance/IRASim/blob/main/dataset/dataset_3D.py -""" - -import os -import pickle -import traceback -import warnings -from concurrent.futures import ThreadPoolExecutor, as_completed - -import numpy as np -import torch -from decord import VideoReader, cpu -from torch.utils.data import Dataset -from torchvision import transforms as T -from tqdm import tqdm - -from cosmos_transfer1.diffusion.training.datasets.dataset_utils import Resize_Preprocess, ToTensorVideo - - -class Dataset(Dataset): - def __init__( - self, - dataset_dir, - sequence_interval, - num_frames, - video_size, - start_frame_interval=1, - ): - """Dataset class for loading image-text-to-video generation data. - - Args: - dataset_dir (str): Base path to the dataset directory - sequence_interval (int): Interval between sampled frames in a sequence - num_frames (int): Number of frames to load per sequence - video_size (list): Target size [H,W] for video frames - - Returns dict with: - - video: RGB frames tensor [T,C,H,W] - - video_name: Dict with episode/frame metadata - """ - - super().__init__() - self.dataset_dir = dataset_dir - self.start_frame_interval = start_frame_interval - self.sequence_interval = sequence_interval - self.sequence_length = num_frames - - video_dir = os.path.join(self.dataset_dir, "videos") - self.video_paths = [os.path.join(video_dir, f) for f in os.listdir(video_dir) if f.endswith(".mp4")] - # print(f"{len(self.video_paths)} trajectories in total") - print(f"{len(self.video_paths)} videos in total") - - # self.t5_dir = os.path.join(self.dataset_dir, "labels") - self.t5_dir = os.path.join(self.dataset_dir, "t5_xxl") - self.samples = self._init_samples(self.video_paths) - self.samples = sorted(self.samples, key=lambda x: (x["video_path"], x["frame_ids"][0])) - print(f"{len(self.samples)} samples in total") - self.wrong_number = 0 - self.preprocess = T.Compose([ToTensorVideo(), Resize_Preprocess(tuple(video_size))]) - - def __str__(self): - return f"{len(self.video_paths)} samples from {self.dataset_dir}" - - def _init_samples(self, video_paths): - samples = [] - with ThreadPoolExecutor(32) as executor: - future_to_video_path = { - executor.submit(self._load_and_process_video_path, video_path): video_path for video_path in video_paths - } - for future in tqdm(as_completed(future_to_video_path), total=len(video_paths)): - samples.extend(future.result()) - return samples - - def _load_and_process_video_path(self, video_path): - vr = VideoReader(video_path, ctx=cpu(0), num_threads=2) - n_frames = len(vr) - - samples = [] - for frame_i in range(0, n_frames, self.start_frame_interval): - sample = dict() - sample["video_path"] = video_path - sample["t5_embedding_path"] = os.path.join( - # self.t5_dir, os.path.basename(video_path).replace(".mp4", ".npy") - self.t5_dir, - os.path.basename(video_path).replace(".mp4", ".pickle"), - ) - sample["frame_ids"] = [] - curr_frame_i = frame_i - while True: - if curr_frame_i > (n_frames - 1): - break - sample["frame_ids"].append(curr_frame_i) - if len(sample["frame_ids"]) == self.sequence_length: - break - curr_frame_i += self.sequence_interval - # make sure there are sequence_length number of frames - if len(sample["frame_ids"]) == self.sequence_length: - samples.append(sample) - return samples - - def __len__(self): - return len(self.samples) - - def _load_video(self, video_path, frame_ids): - vr = VideoReader(video_path, ctx=cpu(0), num_threads=2) - assert (np.array(frame_ids) < len(vr)).all() - assert (np.array(frame_ids) >= 0).all() - vr.seek(0) - frame_data = vr.get_batch(frame_ids).asnumpy() - try: - fps = vr.get_avg_fps() - except Exception: # failed to read FPS - fps = 24 - return frame_data, fps - - def _get_frames(self, video_path, frame_ids): - frames, fps = self._load_video(video_path, frame_ids) - frames = frames.astype(np.uint8) - frames = torch.from_numpy(frames).permute(0, 3, 1, 2) # (l, c, h, w) - frames = self.preprocess(frames) - frames = torch.clamp(frames * 255.0, 0, 255).to(torch.uint8) - return frames, fps - - def __getitem__(self, index): - try: - sample = self.samples[index] - video_path = sample["video_path"] - frame_ids = sample["frame_ids"] - - data = dict() - - video, fps = self._get_frames(video_path, frame_ids) - video = video.permute(1, 0, 2, 3) # Rearrange from [T, C, H, W] to [C, T, H, W] - data["video"] = video - data["video_name"] = { - "video_path": video_path, - "t5_embedding_path": sample["t5_embedding_path"], - "start_frame_id": str(frame_ids[0]), - } - - # Just add these to fit the interface - # t5_embedding = np.load(sample["t5_embedding_path"])[0] - with open(sample["t5_embedding_path"], "rb") as f: - t5_embedding = pickle.load(f)[0] - - data["t5_text_embeddings"] = torch.from_numpy(t5_embedding).cuda() - data["t5_text_mask"] = torch.ones(512, dtype=torch.int64).cuda() - data["fps"] = fps - data["image_size"] = torch.tensor([704, 1280, 704, 1280]).cuda() - data["num_frames"] = self.sequence_length - data["padding_mask"] = torch.zeros(1, 704, 1280).cuda() - - return data - except Exception: - warnings.warn( - f"Invalid data encountered: {self.samples[index]['video_path']}. Skipped " - f"(by randomly sampling another sample in the same dataset)." - ) - warnings.warn("FULL TRACEBACK:") - warnings.warn(traceback.format_exc()) - self.wrong_number += 1 - print(self.wrong_number) - return self[np.random.randint(len(self.samples))] - - -if __name__ == "__main__": - dataset = Dataset( - dataset_dir="assets/example_training_data/", - sequence_interval=1, - num_frames=57, - video_size=[240, 360], - ) - - indices = [0, 13, 200, -1] - for idx in indices: - data = dataset[idx] - print( - ( - f"{idx=} " - f"{data['video'].sum()=}\n" - f"{data['video'].shape=}\n" - f"{data['video_name']=}\n" - f"{data['t5_text_embeddings'].shape=}\n" - "---" - ) - ) diff --git a/scripts/convert_ckpt_fsdp_to_tp.py b/scripts/convert_ckpt_fsdp_to_tp.py new file mode 100644 index 00000000..50492cf3 --- /dev/null +++ b/scripts/convert_ckpt_fsdp_to_tp.py @@ -0,0 +1,129 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +''' +Converting an FSDP checkpoint to a TP checkpoint. +''' +import sys +import os +import torch +from tqdm import tqdm +from collections import OrderedDict +from typing import Dict, Any, List + +from cosmos_transfer1.utils.easy_io import easy_io + + +TP_SIZE = 8 + +def is_column(key: str) -> bool: + """Check if the given key corresponds to a column-parallel parameter.""" + return ( + key.endswith("to_q.0.weight") + or key.endswith("to_k.0.weight") + or key.endswith("to_v.0.weight") + or key.endswith("block.layer1.weight") + ) + + +def is_row(key: str) -> bool: + """Check if the given key corresponds to a row-parallel parameter.""" + return key.endswith("to_out.0.weight") or key.endswith("block.layer2.weight") + + +def native_to_tp(reg_state_dict: Dict[str, Any], tp_size: int) -> List[OrderedDict]: + """Convert a regular state dict to tensor parallel state dicts. + + Args: + reg_state_dict: The regular state dictionary. + tp_size: The number of tensor parallel partitions. + + Returns: + A list of OrderedDicts, each representing a tensor parallel partition. + """ + tp_state_dict = [OrderedDict() for _ in range(tp_size)] + + for key, value in reg_state_dict.items(): + if key.endswith("_extra_state"): + continue + + if is_column(key): + for i, item in enumerate(value.chunk(tp_size, dim=0)): + tp_state_dict[i][key] = item + elif is_row(key): + for i, item in enumerate(value.chunk(tp_size, dim=1)): + tp_state_dict[i][key] = item + else: + for i in range(tp_size): + tp_state_dict[i][key] = value + + return tp_state_dict + + +def convert_fsdp_to_tp(path_in: str, path_out: str) -> None: + """Convert an FSDP checkpoint to TP format. + + Args: + path_in: Path to input checkpoint (without _reg_model.pt suffix) + path_out: Path for output checkpoint (without _mp_X.pt suffix) + tp_size: Number of tensor parallel partitions + verbose: Whether to show progress bar + + Raises: + FileNotFoundError: If input checkpoint doesn't exist + ValueError: If paths are invalid or tp_size <= 0 + RuntimeError: For other conversion errors + """ + try: + native_ckpt = torch.load( + path_in, + map_location=torch.device("cpu"), + ) + state_dicts = native_to_tp(native_ckpt, TP_SIZE) + except FileNotFoundError: + raise FileNotFoundError(f"Checkpoint file {path_in} not found") + except Exception as e: + raise RuntimeError(f"Error loading checkpoint: {str(e)}") + + for i in tqdm(range(TP_SIZE)): + state_dict = {"model": state_dicts[i]} + easy_io.dump(state_dict, f"{path_out}_mp_{i}.pt") + + +if __name__ == "__main__": + ''' + Example usage: converting a viscontrol model to a TP checkpoint. + + Command: + python convert_ckpt_fsdp_to_tp.py checkpoints/nvidia/Cosmos-Transfer1-7B/vis_control.pt + + This will save the Tensor Parallel (TP) checkpoints as 8 files in the same directory: + checkpoints/nvidia/Cosmos-Transfer1-7B/vis_control_tp_mp_0.pt + ... + checkpoints/nvidia/Cosmos-Transfer1-7B/vis_control_tp_mp_7.pt + ''' + if len(sys.argv) != 2: + print("Usage: python convert_ckpt_fsdp_to_tp.py ") + print("Example: python convert_ckpt_fsdp_to_tp.py checkpoints/model.pt") + sys.exit(1) + + checkpoint_path = sys.argv[1] + out_tp_checkpoint_path = os.path.basename(checkpoint_path).replace(".pt", "") + try: + convert_fsdp_to_tp(checkpoint_path, out_tp_checkpoint_path) + print("Conversion completed successfully!") + except Exception as e: + print(f"Error during conversion: {str(e)}") + sys.exit(1) From 5149ca68b6635716aaf91ed821f31a08d2903e98 Mon Sep 17 00:00:00 2001 From: Qianli Ma Date: Tue, 15 Apr 2025 00:27:41 -0700 Subject: [PATCH 05/10] fix: fix conflict in DiTEncoder --- .../networks/general_dit_ctrl_enc.py | 105 ++++++------------ 1 file changed, 33 insertions(+), 72 deletions(-) diff --git a/cosmos_transfer1/diffusion/networks/general_dit_ctrl_enc.py b/cosmos_transfer1/diffusion/networks/general_dit_ctrl_enc.py index 98557d96..faeac7e7 100644 --- a/cosmos_transfer1/diffusion/networks/general_dit_ctrl_enc.py +++ b/cosmos_transfer1/diffusion/networks/general_dit_ctrl_enc.py @@ -252,8 +252,9 @@ def forward( outs = {} - # If also training base model, sometimes drop the controlnet branch to only train base branch. - # This is to prevent the network become dependent on controlnet branch and make control weight useless. + # (Experimental, not used in the released model) if also training base model, sometimes drop the + # controlnet branch to only train base branch. This is to prevent the network become dependent on + # controlnet branch and make control weight useless. is_training = torch.is_grad_enabled() is_training_base_model = any(p.requires_grad for p in base_model.parameters()) if is_training and is_training_base_model: @@ -266,17 +267,9 @@ def forward( coin_flip = 1 num_control_blocks = self.layer_mask.index(True) - if self.random_drop_control_blocks: - if is_training: # Use a random number of layers during training. - num_layers_to_use = np.random.randint(num_control_blocks) + 1 - elif num_layers_to_use == -1: # Evaluate using all the layers. - num_layers_to_use = num_control_blocks - else: # Use the specified number of layers during inference. - pass - else: # Use all of the layers. - num_layers_to_use = num_control_blocks + num_layers_to_use = num_control_blocks control_gate_per_layer = [i < num_layers_to_use for i in range(num_control_blocks)] - + if isinstance(control_weight, torch.Tensor): if control_weight.ndim == 0: # Single scalar tensor control_weight = [float(control_weight)] * len(guided_hints) @@ -287,7 +280,6 @@ def forward( else: control_weight = [control_weight] * len(guided_hints) - # max_norm = {} x_before_blocks = x.clone() for i, guided_hint in enumerate(guided_hints): x = x_before_blocks @@ -337,34 +329,28 @@ def forward( self.affline_scale_log_info = affline_scale_log_info self.affline_emb = affline_emb_B_D - if self.blocks["block0"].x_format == "THWBD": - x = rearrange(x_B_T_H_W_D, "B T H W D -> T H W B D") + x = rearrange(x_B_T_H_W_D, "B T H W D -> T H W B D") + if extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D is not None: + extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D = rearrange( + extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D, "B T H W D -> T H W B D" + ) + + if self.sequence_parallel: + tp_group = parallel_state.get_tensor_model_parallel_group() + # Sequence parallel requires the input tensor to be scattered along the first dimension. + assert self.block_config == "FA-CA-MLP" # Only support this block config for now + T, H, W, B, D = x.shape + # variable name x_T_H_W_B_D is no longer valid. x is reshaped to THW*1*1*b*D and will be reshaped back in FinalLayer + x = x.view(T * H * W, 1, 1, B, D) + assert x.shape[0] % parallel_state.get_tensor_model_parallel_world_size() == 0 + x = scatter_along_first_dim(x, tp_group) if extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D is not None: - extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D = rearrange( - extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D, "B T H W D -> T H W B D" + extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D = extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D.view( + T * H * W, 1, 1, B, D + ) + extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D = scatter_along_first_dim( + extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D, tp_group ) - - if self.sequence_parallel: - tp_group = parallel_state.get_tensor_model_parallel_group() - # Sequence parallel requires the input tensor to be scattered along the first dimension. - assert self.block_config == "FA-CA-MLP" # Only support this block config for now - T, H, W, B, D = x.shape - # variable name x_T_H_W_B_D is no longer valid. x is reshaped to THW*1*1*b*D and will be reshaped back in FinalLayer - x = x.view(T * H * W, 1, 1, B, D) - assert x.shape[0] % parallel_state.get_tensor_model_parallel_world_size() == 0 - x = scatter_along_first_dim(x, tp_group) - if extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D is not None: - extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D = extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D.view( - T * H * W, 1, 1, B, D - ) - extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D = scatter_along_first_dim( - extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D, tp_group - ) - - elif self.blocks["block0"].x_format == "BTHWD": - x = x_B_T_H_W_D - else: - raise ValueError(f"Unknown x_format {self.blocks[0].x_format}") for idx, (name, block) in enumerate(blocks.items()): assert ( @@ -388,40 +374,15 @@ def forward( hint_val = zero_blocks[name](x) * control_weight[i] * coin_flip * gate else: # Spatial-temporal weights [num_controls, B, 1, T, H, W] control_feat = zero_blocks[name](x) - # Get current feature dimensions - if self.blocks["block0"].x_format == "THWBD": - weight_map = control_weight[i] # [B, 1, T, H, W] - - if weight_map.shape[2:5] != (T, H, W): - assert weight_map.shape[2] == 8 * (T - 1) + 1 - weight_map_i = [ - torch.nn.functional.interpolate( - weight_map[:, :, :1, :, :], - size=(1, H, W), - mode="trilinear", - align_corners=False, - ) - ] - for wi in range(1, weight_map.shape[2], 8): - weight_map_i += [ - torch.nn.functional.interpolate( - weight_map[:, :, wi : wi + 8], - size=(1, H, W), - mode="trilinear", - align_corners=False, - ) - ] - weight_map = torch.cat(weight_map_i, dim=2) - - # Reshape to match THWBD format - weight_map = weight_map.permute(2, 3, 4, 0, 1) # [T, H, W, B, 1] - weight_map = weight_map.view(T * H * W, 1, 1, B, 1) - if self.sequence_parallel: - weight_map = scatter_along_first_dim(weight_map, tp_group) - - else: # BTHWD format - raise NotImplementedError("BTHWD format for weight map is not implemented yet.") + weight_map = control_weight[i] # [B, 1, T, H, W] + # Reshape to match THWBD format + weight_map = weight_map.permute(2, 3, 4, 0, 1) # [T, H, W, B, 1] + weight_map = weight_map.view(T * H * W, 1, 1, B, 1) + + if self.sequence_parallel: + weight_map = scatter_along_first_dim(weight_map, tp_group) + hint_val = control_feat * weight_map * coin_flip * gate if name not in outs: From 47f6e0836c1436963d5a3c4457f36777e03b1c9c Mon Sep 17 00:00:00 2001 From: Qianli Ma Date: Tue, 15 Apr 2025 00:34:11 -0700 Subject: [PATCH 06/10] cleanup --- .../networks/general_dit_ctrl_enc.py | 26 +++++++------------ 1 file changed, 10 insertions(+), 16 deletions(-) diff --git a/cosmos_transfer1/diffusion/networks/general_dit_ctrl_enc.py b/cosmos_transfer1/diffusion/networks/general_dit_ctrl_enc.py index faeac7e7..c66a4488 100644 --- a/cosmos_transfer1/diffusion/networks/general_dit_ctrl_enc.py +++ b/cosmos_transfer1/diffusion/networks/general_dit_ctrl_enc.py @@ -140,17 +140,12 @@ def encode_hint( hint_B_T_H_W_D, _ = self.prepare_hint_embedded_sequence(hint, fps=fps, padding_mask=padding_mask) - if self.blocks["block0"].x_format == "THWBD": - hint = rearrange(hint_B_T_H_W_D, "B T H W D -> T H W B D") - if self.sequence_parallel: - tp_group = parallel_state.get_tensor_model_parallel_group() - T, H, W, B, D = hint.shape - hint = hint.view(T * H * W, 1, 1, B, -1) - hint = scatter_along_first_dim(hint, tp_group) - elif self.blocks["block0"].x_format == "BTHWD": - hint = hint_B_T_H_W_D - else: - raise ValueError(f"Unknown x_format {self.blocks[0].x_format}") + hint = rearrange(hint_B_T_H_W_D, "B T H W D -> T H W B D") + if self.sequence_parallel: + tp_group = parallel_state.get_tensor_model_parallel_group() + T, H, W, B, D = hint.shape + hint = hint.view(T * H * W, 1, 1, B, -1) + hint = scatter_along_first_dim(hint, tp_group) guided_hint = self.input_hint_block(hint) return guided_hint @@ -245,10 +240,9 @@ def forward( else: crossattn_mask = None - if self.blocks["block0"].x_format == "THWBD": - crossattn_emb = rearrange(crossattn_emb, "B M D -> M B D") - if crossattn_mask: - crossattn_mask = rearrange(crossattn_mask, "B M -> M B") + crossattn_emb = rearrange(crossattn_emb, "B M D -> M B D") + if crossattn_mask: + crossattn_mask = rearrange(crossattn_mask, "B M -> M B") outs = {} @@ -379,7 +373,7 @@ def forward( # Reshape to match THWBD format weight_map = weight_map.permute(2, 3, 4, 0, 1) # [T, H, W, B, 1] weight_map = weight_map.view(T * H * W, 1, 1, B, 1) - + if self.sequence_parallel: weight_map = scatter_along_first_dim(weight_map, tp_group) From dbf668347dab50e2f4d324deca71f1c977f0bb2c Mon Sep 17 00:00:00 2001 From: Qianli Ma Date: Tue, 15 Apr 2025 01:27:36 -0700 Subject: [PATCH 07/10] feat: compelete README in examples/ for post/pre-training; update the main README --- README.md | 16 +- .../diffusion/config/base/data.py | 2 +- .../experiment/ctrl_7b_tp_121frames.py | 61 +++-- .../diffusion/datasets/augmentor_provider.py | 2 +- .../datasets/augmentors/merge_datadict.py | 4 +- .../datasets/example_transfer_dataset.py | 10 +- .../diffusion/inference/inference_utils.py | 22 +- .../post-training_cosmos_transfer_7b_edge.md | 213 ---------------- ...process_control_input_data_for_training.md | 108 ++++++++ examples/training_cosmos_transfer_7b.md | 230 ++++++++++++++++++ scripts/convert_ckpt_fsdp_to_tp.py | 14 +- 11 files changed, 416 insertions(+), 266 deletions(-) delete mode 100644 examples/post-training_cosmos_transfer_7b_edge.md create mode 100644 examples/process_control_input_data_for_training.md create mode 100644 examples/training_cosmos_transfer_7b.md diff --git a/README.md b/README.md index 7a2e9b3d..1ba9fbf7 100644 --- a/README.md +++ b/README.md @@ -15,8 +15,8 @@ Cosmos-Transfer1 includes the following: - **ControlNet-based single modality conditional world generation** where a user can generate visual simulation based on one of the following modalities: segmentation video, depth video, edge video, blur video, LiDAR video, or HDMap video. Cosmos-Transfer1 generates a video based on the signal modality conditional input, a user text prompt, and, optionally, an input RGB video frame prompt (which could be from the last video generation result when operating in the autoregressive setting). We will use Cosmos-Transfer1-7B [Modality] to refer to the model operating in this setting. For example, Cosmos-Transfer1-7B [Depth] refers to a depth ControlNet model. - **MultiControlNet-based multimodal conditional world generation** where a user can generate visual simulation based on any combination of segmentation video, depth video, edge video, and blur video (LiDAR video and HDMap in the AV sample) with a spatiotemporal control map to control the stregnth of each modality across space and time. Cosmos-Transfer1 generates a video based on the multimodal conditional inputs, a user text prompt, and, optionally, an input RGB video frame prompt (This could be from the last video generation result when operating in the autoregressive setting.). This is the preferred mode of Cosmos-Transfer. We will refer it as Cosmos-Transfer1-7B. - **4KUpscaler** for upscaling a 720p-resolution video to a 4K-resolution video. -- **Post-training scripts** for helping Physical AI builders post-train pre-trained Cosmos-Transfer1 for their applications [Coming soon]. -- **Pre-training scripts** for helping Physical AI builders train their own Cosmos-Transfer1 models from scratch [Coming soon]. +- **Post-training scripts** for helping Physical AI builders post-train pre-trained Cosmos-Transfer1 for their applications. +- **Pre-training scripts** for helping Physical AI builders train their own Cosmos-Transfer1 models from scratch. ## Example Model Behavior @@ -55,22 +55,14 @@ Please refer to [INSTALL.md](INSTALL.md) for general instructions on environment ### Post-train pre-trained Cosmos-Transfer1 models * Post-train diffusion-based Text2World models using custom datasets [with multi-node support]Coming soon -* Post-train pre-trained Cosmos-Transfer1-7B [Depth]: Coming soon -* Post-train pre-trained Cosmos-Transfer1-7B [Segmentation]: Coming soon -* Post-train pre-trained Cosmos-Transfer1-7B [Edge]: Coming soon -* Post-train pre-trained Cosmos-Transfer1-7B [Vis]: Coming soon -* Post-train pre-trained Cosmos-Transfer1-7B [Keypoint]: Coming soon +* [Post-train pre-trained Cosmos-Transfer1-7B [Depth|Segmentation|Edge|Vis|Keypoint]](examples/training_cosmos_transfer_7b.md) **[with multi-GPU support]** * Post-train pre-trained Cosmos-Transfer1-7B-Sample-AV [LiDAR]: Coming soon * Post-train pre-trained Cosmos-Transfer1-7B-Sample-AV [HDMap]: Coming soon * Post-train pre-trained Cosmos-Transfer1-7B-Sample-AV-Multiview: Coming soon ### Build your own Cosmos-Transfer1 models from scratch -* Pre-train Cosmos-Transfer1-7B [Depth]: Coming soon -* Pre-train Cosmos-Transfer1-7B [Segmentation]: Coming soon -* Pre-train Cosmos-Transfer1-7B [Edge]: Coming soon -* Pre-train Cosmos-Transfer1-7B [Vis]: Coming soon -* Pre-train Cosmos-Transfer1-7B [Keypoint]: Coming soon +* [Pre-train pre-trained Cosmos-Transfer1-7B [Depth|Segmentation|Edge|Vis|Keypoint]](examples/training_cosmos_transfer_7b.md) **[with multi-GPU support]** * Pre-train Cosmos-Transfer1-7B-Sample-AV [LiDAR]: Coming soon * Pre-train Cosmos-Transfer1-7B-Sample-AV [HDMap]: Coming soon diff --git a/cosmos_transfer1/diffusion/config/base/data.py b/cosmos_transfer1/diffusion/config/base/data.py index 283434a4..38e16300 100644 --- a/cosmos_transfer1/diffusion/config/base/data.py +++ b/cosmos_transfer1/diffusion/config/base/data.py @@ -20,7 +20,7 @@ def get_sampler(dataset): def get_example_transfer_dataset(hint_key, is_train=True): dataset = L(ExampleTransferDataset)( - dataset_dir="datasets/example_transfer_training_data", + dataset_dir="datasets/hdvila", chunk_size=256, num_frames=121, resolution="720", diff --git a/cosmos_transfer1/diffusion/config/training/experiment/ctrl_7b_tp_121frames.py b/cosmos_transfer1/diffusion/config/training/experiment/ctrl_7b_tp_121frames.py index e2fc161c..356b71a0 100644 --- a/cosmos_transfer1/diffusion/config/training/experiment/ctrl_7b_tp_121frames.py +++ b/cosmos_transfer1/diffusion/config/training/experiment/ctrl_7b_tp_121frames.py @@ -18,18 +18,24 @@ The configs are registered under the group "experiment" and can be used in training by passing the experiment name as an argument. Example usage: - - [dryrun, generate and inspect EdgeControl config] torchrun --nproc_per_node=1 -m cosmos_transfer1.diffusion.training.train --dryrun --config=cosmos_transfer1/diffusion/config/config_train.py -- experiment=CTRL_7Bv1pt3_lvg_tp_121frames_control_input_edge_block3 - - [real run, 8 gpu, train SegControl] torchrun --nproc_per_node=8 -m cosmos_transfer1.diffusion.training.train --config=cosmos_transfer1/diffusion/config/config_train.py -- experiment=CTRL_7Bv1pt3_lvg_tp_121frames_control_input_seg_block3 - - [real run, 8 gpu, train DepthControl] torchrun --nproc_per_node=8 -m cosmos_transfer1.diffusion.training.train --config=cosmos_transfer1/diffusion/config/config_train.py -- experiment=CTRL_7Bv1pt3_lvg_tp_121frames_control_input_depth_block3 + - [dryrun, generate and inspect EdgeControl config]: + torchrun --nproc_per_node=1 -m cosmos_transfer1.diffusion.training.train --dryrun --config=cosmos_transfer1/diffusion/config/config_train.py -- experiment=CTRL_7Bv1pt3_lvg_tp_121frames_control_input_edge_block3_pretrain + - [real run, 8 gpu, train SegControl from scratch]: + torchrun --nproc_per_node=8 -m cosmos_transfer1.diffusion.training.train --config=cosmos_transfer1/diffusion/config/config_train.py -- experiment=CTRL_7Bv1pt3_lvg_tp_121frames_control_input_seg_block3_pretrain + - [real run, 8 gpu, train SegControl from released checkpoint]: + torchrun --nproc_per_node=8 -m cosmos_transfer1.diffusion.training.train --config=cosmos_transfer1/diffusion/config/config_train.py -- experiment=CTRL_7Bv1pt3_lvg_tp_121frames_control_input_seg_block3_posttrain """ from hydra.core.config_store import ConfigStore +import os from cosmos_transfer1.utils.lazy_config import LazyCall as L from cosmos_transfer1.utils.lazy_config import LazyDict from cosmos_transfer1.diffusion.config.transfer.conditioner import CTRL_HINT_KEYS_COMB from cosmos_transfer1.diffusion.training.models.model_ctrl import VideoDiffusionModelWithCtrl # this one has training support from cosmos_transfer1.diffusion.networks.general_dit_video_conditioned import VideoExtendGeneralDIT +from cosmos_transfer1.diffusion.inference.inference_utils import default_model_names +from cosmos_transfer1.checkpoints import BASE_7B_CHECKPOINT_PATH, COSMOS_TRANSFER1_7B_CHECKPOINT cs = ConfigStore.instance() @@ -39,13 +45,18 @@ num_control_blocks = 3 - def make_ctrlnet_config_7b_training( hint_key: str = "control_input_canny", num_control_blocks: int = 3, + pretrain_model_path: str = "" ) -> LazyDict: + if pretrain_model_path == "": + job_name = f"CTRL_7Bv1pt3_lvg_tp_121frames_{hint_key}_block{num_control_blocks}_pretrain" + job_project = "cosmos_transfer1_pretrain" + else: + job_name = f"CTRL_7Bv1pt3_lvg_tp_121frames_{hint_key}_block{num_control_blocks}_posttrain" + job_project = "cosmos_transfer1_posttrain" - # Create the complete configuration in one step config = LazyDict( dict( defaults=[ @@ -64,10 +75,11 @@ def make_ctrlnet_config_7b_training( {"override /data_val": f"example_transfer_val_data_{hint_key}"}, "_self_", ], + # ckpt, config yaml files etc. will be saved under checkpoints//// job=dict( + project=job_project, group="CTRL_7Bv1_lvg", - name=f"CTRL_7Bv1pt3_lvg_tp_121frames_{hint_key}_block{num_control_blocks}", - project="cosmos_transfer1_posttrain", + name=job_name, ), optimizer=dict( lr=2 ** (-14.3), # ~5e-5 @@ -76,10 +88,7 @@ def make_ctrlnet_config_7b_training( eps=1e-10, ), checkpoint=dict( - # Modify load_path as needed if you do post-training (fine-tuning). - # If training from scratch, leave it empty. - # Here we assume post-train our pre-trained VisControl model. - load_path="checkpoints/nvidia/Cosmos-Transfer1-7B/vis_control.pt", + load_path=pretrain_model_path, # Modify load_path as needed if you do post-training (fine-tuning). If training from scratch, leave it empty. broadcast_via_filesystem=True, save_iter=1000, load_training_state=False, @@ -107,8 +116,8 @@ def make_ctrlnet_config_7b_training( 160, ], base_load_from=dict( - load_path="checkpoints/nvidia/Cosmos-Transfer1-7B/base_model.pt", - ), # modify as needed. This is the base model ckpt (that's frozen during training). + load_path=os.path.join(COSMOS_TRANSFER1_7B_CHECKPOINT, "checkpoints_tp", "base_model_mp_*.pt") + ), # modify as needed. This is the TP version of base model ckpt (that's frozen during training). finetune_base_model=False, hint_mask=[True] * len(CTRL_HINT_KEYS_COMB[hint_key]), hint_dropout_rate=0.3, @@ -163,15 +172,31 @@ def make_ctrlnet_config_7b_training( """ Register configurations -The loop below will register all experiments CTRL_7Bv1pt3_lvg_tp_121frames_control_input_{hint_key_name}_block3 for each hint_key_name -and then in training command, simply need to pass the "experiment" arg to override the configs. See the docstring at top of this script -for an example. +The loop below will register ALL experiments CTRL_7Bv1pt3_lvg_tp_121frames_control_input_{hint_key_name}_block3_{pretrain_or_posttrain} for ALL hint_key_name. +Then in training command, simply need to pass the "experiment" arg to override the configs. See the docstring at top of this script for an example. + +# NOTE: To launch real post-training, convert the checkpoints to TP checkpoints first. See scripts/convert_ckpt_fsdp_to_tp.py. """ for key in CTRL_HINT_KEYS_COMB.keys(): - config = make_ctrlnet_config_7b_training(hint_key=key, num_control_blocks=num_control_blocks) + # Register experiments for pretraining from scratch + config = make_ctrlnet_config_7b_training(hint_key=key, num_control_blocks=num_control_blocks, pretrain_model_path="") + cs.store( + group="experiment", + package="_global_", + name=config["job"]["name"], + node=config, + ) + + # Register experiments for post-training from TP checkpoints. + hint_key_short = key.replace("control_input_", "") # "control_input_vis" -> "vis" + base_ckpt_path = default_model_names[hint_key_short] + tp_ckpt_path = os.path.join(os.path.dirname(base_ckpt_path), "checkpoints_tp", os.path.basename(base_ckpt_path)) + config = make_ctrlnet_config_7b_training(hint_key=key, num_control_blocks=num_control_blocks, pretrain_model_path=tp_ckpt_path) + print(tp_ckpt_path, '=======\n\n') + import ipdb; ipdb.set_trace() cs.store( group="experiment", package="_global_", name=config["job"]["name"], node=config, - ) \ No newline at end of file + ) diff --git a/cosmos_transfer1/diffusion/datasets/augmentor_provider.py b/cosmos_transfer1/diffusion/datasets/augmentor_provider.py index 7be95b23..b50a6a8e 100644 --- a/cosmos_transfer1/diffusion/datasets/augmentor_provider.py +++ b/cosmos_transfer1/diffusion/datasets/augmentor_provider.py @@ -85,7 +85,7 @@ def _get_video_ctrlnet_augmentor( resolution: str, blur_config: BlurAugmentorConfig, ): - if hint_key == "control_input_human_kpts": + if hint_key == "control_input_keypoint": add_control_input = L(AddControlInputComb)( input_keys=["", "video"], output_keys=[hint_key], diff --git a/cosmos_transfer1/diffusion/datasets/augmentors/merge_datadict.py b/cosmos_transfer1/diffusion/datasets/augmentors/merge_datadict.py index bfc61897..5703b6ac 100644 --- a/cosmos_transfer1/diffusion/datasets/augmentors/merge_datadict.py +++ b/cosmos_transfer1/diffusion/datasets/augmentors/merge_datadict.py @@ -43,8 +43,8 @@ def __call__(self, data_dict: dict) -> dict: key_dict = data_dict.pop(key) if key == "depth" and "depth" in self.output_keys: data_dict["depth"] = key_dict - if key == "human_kpts" and "human_kpts" in self.output_keys: - data_dict["human_kpts"] = key_dict + if key == "keypoint" and "keypoint" in self.output_keys: + data_dict["keypoint"] = key_dict elif key == "segmentation" and "segmentation" in self.output_keys: data_dict["segmentation"] = key_dict for sub_key in key_dict: diff --git a/cosmos_transfer1/diffusion/datasets/example_transfer_dataset.py b/cosmos_transfer1/diffusion/datasets/example_transfer_dataset.py index 6ee69eb6..2c47ea53 100644 --- a/cosmos_transfer1/diffusion/datasets/example_transfer_dataset.py +++ b/cosmos_transfer1/diffusion/datasets/example_transfer_dataset.py @@ -37,12 +37,12 @@ CTRL_AUG_KEYS = { "depth": "depth", "seg": "segmentation", - "human_kpts": "human_kpts", + "keypoint": "keypoint", } # mappings between control types and corresponding sub-folders names in the data folder CTRL_TYPE_INFO = { - "human_kpts": {"folder": "human_kpts", "format": "pickle", "data_dict_key": "human_kpts"}, + "keypoint": {"folder": "keypoint", "format": "pickle", "data_dict_key": "keypoint"}, "depth": {"folder": "depth", "format": "mp4", "data_dict_key": "depth"}, "seg": {"folder": "seg", "format": "pickle", "data_dict_key": "segmentation"}, "edge": {"folder": None}, # Canny edge, computed on-the-fly @@ -190,10 +190,10 @@ def _load_control_data(self, sample): ctrl_data = pickle.load(f) # key should match line 982 at cosmos_transfer1/diffusion/datasets/augmentors/control_input.py data_dict["segmentation"] = ctrl_data - elif self.ctrl_type == "human_kpts": + elif self.ctrl_type == "keypoint": with open(ctrl_path, 'rb') as f: ctrl_data = pickle.load(f) - data_dict["human_kpts"] = ctrl_data + data_dict["keypoint"] = ctrl_data elif self.ctrl_type == "depth": vr = VideoReader(ctrl_path, ctx=cpu(0)) # Ensure the depth video has the same number of frames @@ -303,7 +303,7 @@ def __str__(self): if __name__ == "__main__": dataset = ExampleTransferDataset( - dataset_dir="assets/example_transfer_training_data/", + dataset_dir="assets/hdvila/", hint_key="control_input_seg", chunk_size=1, num_frames=121, diff --git a/cosmos_transfer1/diffusion/inference/inference_utils.py b/cosmos_transfer1/diffusion/inference/inference_utils.py index c8071910..40d1fbd7 100644 --- a/cosmos_transfer1/diffusion/inference/inference_utils.py +++ b/cosmos_transfer1/diffusion/inference/inference_utils.py @@ -68,6 +68,17 @@ "9,16": (704, 1280), } +# Default model names for each control type +default_model_names = { + "vis": VIS2WORLD_CONTROLNET_7B_CHECKPOINT_PATH, + "seg": SEG2WORLD_CONTROLNET_7B_CHECKPOINT_PATH, + "edge": EDGE2WORLD_CONTROLNET_7B_CHECKPOINT_PATH, + "depth": DEPTH2WORLD_CONTROLNET_7B_CHECKPOINT_PATH, + "keypoint": KEYPOINT2WORLD_CONTROLNET_7B_CHECKPOINT_PATH, + "upscale": UPSCALER_CONTROLNET_7B_CHECKPOINT_PATH, + "hdmap": HDMAP2WORLD_CONTROLNET_7B_CHECKPOINT_PATH, + "lidar": LIDAR2WORLD_CONTROLNET_7B_CHECKPOINT_PATH, +} class _IncompatibleKeys( NamedTuple( @@ -914,17 +925,6 @@ def validate_controlnet_specs(cfg, controlnet_specs) -> Dict[str, Any]: sigma_max = cfg.sigma_max input_video_path = cfg.input_video_path - default_model_names = { - "vis": VIS2WORLD_CONTROLNET_7B_CHECKPOINT_PATH, - "seg": SEG2WORLD_CONTROLNET_7B_CHECKPOINT_PATH, - "edge": EDGE2WORLD_CONTROLNET_7B_CHECKPOINT_PATH, - "depth": DEPTH2WORLD_CONTROLNET_7B_CHECKPOINT_PATH, - "keypoint": KEYPOINT2WORLD_CONTROLNET_7B_CHECKPOINT_PATH, - "upscale": UPSCALER_CONTROLNET_7B_CHECKPOINT_PATH, - "hdmap": HDMAP2WORLD_CONTROLNET_7B_CHECKPOINT_PATH, - "lidar": LIDAR2WORLD_CONTROLNET_7B_CHECKPOINT_PATH, - } - for hint_key, config in controlnet_specs.items(): if hint_key not in valid_hint_keys: raise ValueError(f"Invalid hint_key: {hint_key}. Must be one of {valid_hint_keys}") diff --git a/examples/post-training_cosmos_transfer_7b_edge.md b/examples/post-training_cosmos_transfer_7b_edge.md deleted file mode 100644 index 63bb89ca..00000000 --- a/examples/post-training_cosmos_transfer_7b_edge.md +++ /dev/null @@ -1,213 +0,0 @@ -## Post-training diffusion-based EdgeControl models - -### Model Support Matrix - -We support the following Cosmos Diffusion models for post-training. Review the available models and their compute requirements for post-tuning and inference to determine the best model for your use case. - -| Model Name | Model Status | Compute Requirements for Post-Training | -|----------------------------------------------|------------------|------------------------------------------| -| Cosmos-Transfer1-7B | **Supported** | 8 NVIDIA GPUs* | - -**\*** `H100-80GB` or `A100-80GB` GPUs are recommended. - -### Environment setup - -Please refer to the Post-training section of [INSTALL.md](/INSTALL.md#post-training) for instructions on environment setup. - -### Download Checkpoints - -1. Generate a [Hugging Face](https://huggingface.co/settings/tokens) access token. Set the access token to 'Read' permission (default is 'Fine-grained'). - -2. Log in to Hugging Face with the access token: - -```bash -huggingface-cli login -``` - -3. Accept the [LlamaGuard-7b terms](https://huggingface.co/meta-llama/LlamaGuard-7b) - -4. Download the Cosmos model weights from [Hugging Face](https://huggingface.co/collections/nvidia/cosmos-transfer1-67c9d328196453be6e568d3e): - -```bash -CUDA_HOME=$CONDA_PREFIX PYTHONPATH=$(pwd) python scripts/download_checkpoints.py --output_dir checkpoints/ -``` - -Note that this will require about 300GB of free storage. Not all these checkpoints will be used in every generation. - -5. The downloaded files should be in the following structure: - -``` -checkpoints/ -├── nvidia -│ ├── Cosmos-Transfer1-7B -│ │ ├── base_model.pt -│ │ ├── vis_control.pt -│ │ ├── edge_control.pt -│ │ ├── seg_control.pt -│ │ ├── depth_control.pt -│ │ ├── keypoint_control.pt -│ │ ├── 4kupscaler_control.pt -│ │ ├── config.json -│ │ └── guardrail -│ │ ├── aegis/ -│ │ ├── blocklist/ -│ │ ├── face_blur_filter/ -│ │ └── video_content_safety_filter/ -│ │ -│ ├── Cosmos-Transfer1-7B-Sample-AV/ -│ │ ├── base_model.pt -│ │ ├── hdmap_control.pt -│ │ └── lidar_control.pt -│ │ -│ │── Cosmos-Tokenize1-CV8x8x8-720p -│ │ ├── decoder.jit -│ │ ├── encoder.jit -│ │ ├── autoencoder.jit -│ │ └── mean_std.pt -│ │ -│ └── Cosmos-UpsamplePrompt1-12B-Transfer -│ ├── depth -│ │ ├── consolidated.safetensors -│ │ ├── params.json -│ │ └── tekken.json -│ ├── README.md -│ ├── segmentation -│ │ ├── consolidated.safetensors -│ │ ├── params.json -│ │ └── tekken.json -│ ├── seg_upsampler_example.png -│ └── viscontrol -│ ├── consolidated.safetensors -│ ├── params.json -│ └── tekken.json -│ -├── depth-anything/... -├── facebook/... -├── google-t5/... -└── IDEA-Research/ -``` - -### Examples - -Post-training a Cosmos-Transfer1 model enables you to train the model to generate videos that are more specific to your use case. - -There are 3 steps to post-training: downloading a dataset, preprocessing the data, and post-training the model. - -#### 1. Download a Dataset - -The first step is to download a dataset with videos and captions. - -You must provide a folder containing a collection of videos in **MP4 format**, preferably 720p. These videos should focus on the subject throughout the entire video so that each video chunk contains the subject. - -For example, you can use a subset of [HD-VILA-100M](https://github.com/microsoft/XPretrain/tree/main/hd-vila-100m) dataset for post-training. - -```bash -# Download metadata with video urls and captions -mkdir -p datasets/hdvila -cd datasets/hdvila -wget https://huggingface.co/datasets/TempoFunk/hdvila-100M/resolve/main/hdvila-100M.jsonl -``` - -Run the following command to download the sample videos used for post-training: - -```bash -# Requirements for Youtube video downloads & video clipping -pip install pytubefix ffmpeg -``` - -```bash -# The script will downlaod the original HD-VILA-100M videos, save the corresponding clips, the captions and the metadata. -CUDA_HOME=$CONDA_PREFIX PYTHONPATH=$(pwd) python scripts/download_diffusion_example_data.py --dataset_path datasets/hdvila --N_videos 128 --do_download --do_clip -``` - -#### 2. Preprocessing the Data - -Run the following command to pre-compute T5-XXL embeddings for the video captions used for post-training: - -```bash -# The script will read the captions, save the T5-XXL embeddings in pickle format. -CUDA_HOME=$CONDA_PREFIX PYTHONPATH=$(pwd) python scripts/get_t5_embeddings.py --dataset_path datasets/hdvila -``` - -Dataset folder format: -``` -datasets/hdvila/ -├── metas/ -│ ├── *.json -│ ├── *.txt -├── videos/ -│ ├── *.mp4 -├── t5_xxl/ -│ ├── *.pickle -``` - -Training a VisControl or EdgeControl model is self-supervised: we apply blurs and/or compute canny edges of the input videos on-the-fly during training. Therefore, for these two modalities there is no need to prepare the control input videos separately. - -#### 3. Post-train the Model - -Run the following command to execute an example post-training job with the above data. The command below will output the detailed experiment config file at -`checkpoints/cosmos_transfer1_posttrain/CTRL_7Bv1_lvg/CTRL_7Bv1pt3_lvg_tp_121frames_control_input_edge_block3/config.yaml`. -```bash -export OUTPUT_ROOT=checkpoints # default value -torchrun --nproc_per_node=1 -m cosmos_transfer1.diffusion.training.train --dryrun --config=cosmos_transfer1/diffusion/config/config_train.py -- experiment=CTRL_7Bv1pt3_lvg_tp_121frames_control_input_edge_block3 -``` - -Removing the `--dryrun` will start a real training job. -checkpoints/cosmos_transfer1/CTRL_7Bv1_lvg/CTRL_7Bv1pt3_lvg_tp_121frames_control_input_edge_block3/config.yaml -This command will use ``cosmos_transfer1/diffusion/config/training/experiment/ctrl_7b_tp_121frames.py` to register experiments for all `hint_keys` (control modalities). - -Then the model will be post-trained using the above hdvila dataset. -See the function `make_ctrlnet_config_7b_training` defined in `cosmos_transfer1/diffusion/config/training/experiment/ctrl_7b_tp_121frames.py` to understand how the detailed configs of the model, trainer, dataloader etc. are defined. For the data specifically: - -```python -num_frames = 121 -example_video_dataset = L(Dataset)( - dataset_dir="datasets/hdvila", - sequence_interval=1, - num_frames=num_frames, - video_size=(720, 1280), - start_frame_interval=1, -) - -dataloader_train = L(DataLoader)( - dataset=example_video_dataset, - sampler=L(get_sampler)(dataset=example_video_dataset), - batch_size=1, - drop_last=True, -) -... - -config = LazyDict( - dict( - ... - dataloader_train=dataloader_train, - ... - ) -) -... -``` - -The checkpoints will be saved to `${OUTPUT_ROOT}/PROJECT/GROUP/NAME`. -In the above example, `PROJECT` is `cosmos_transfer1_posttrain`, `GROUP` is `CTRL_7Bv1_lvg`, `NAME` is `CTRL_7Bv1pt3_lvg_tp_121frames_control_input_edge_block3`. - -See the job config to understand how they are determined. -```python -edgecontrol_7b_example_hdvila = LazyDict( - dict( - ... - job=dict( - project="cosmos_transfer1_posttrain", - group="CTRL_7Bv1_lvg", - name="CTRL_7Bv1pt3_lvg_tp_121frames_control_input_edge_block3", - ), - ... - ) -) -``` - -During the training, the checkpoints will be saved in the below structure. -``` -checkpoints/cosmos_transfer1_posttrain/CTRL_7Bv1_lvg/CTRL_7Bv1pt3_lvg_tp_121frames_control_input_edge_block3/checkpoints/ -├── iter_{NUMBER}_reg_model.pt -├── iter_{NUMBER}_ema_model.pt -``` \ No newline at end of file diff --git a/examples/process_control_input_data_for_training.md b/examples/process_control_input_data_for_training.md new file mode 100644 index 00000000..d7767127 --- /dev/null +++ b/examples/process_control_input_data_for_training.md @@ -0,0 +1,108 @@ +# Processing Control Input Data for Training + +This document provides detailed information about preparing control input data for training different Cosmos-Transfer1 models. + +## DepthControl Training Data Format + +- Requires depth videos in MP4 format +- Must be frame-wise aligned with corresponding RGB videos, and has same [H, W] dimensions as the input videos. +- Place in `depth/` directory + +## SegControl Training Data Format + +The segmentation data is stored in pickle files, one per video. After loading a pickle file, the data structure is as follows: + +```python +[ + { # First detected object + 'phrase': str, # Name/description of the detected object + 'segmentation_mask_rle': { + 'data': bytes, # Run-length encoded binary mask data + 'mask_shape': tuple # Shape of the mask (height, width) + } + }, + { # Second detected object + 'phrase': str, + 'segmentation_mask_rle': { + 'data': bytes, + 'mask_shape': tuple + } + }, + # ... more detected objects +] +``` + +#### Key Components: + +1. **Object Detection**: + - List of dictionaries, one per detected object + - Each detection contains: + - `phrase`: String describing the object + - `segmentation_mask_rle`: Dictionary containing: + - `data`: RLE-encoded binary mask data + - `mask_shape`: Tuple specifying mask dimensions (height, width) + +2. **Mask Creation**: + - Reference implementation in `cosmos_transfer1/auxiliary/sam2/sam2_model.py` + + +## KeypointControl Training Data Format + +For training KeypointControl models, you need to provide a pickle file containing 2D human keypoint annotations for each frame. The pickle file should follow this structure: + +```python +{ + frame_id: [ # List of detected humans in this frame + { # Annotation for one human + 'human-bbox': np.array([x1, y1, x2, y2, confidence], dtype=np.float16), # Normalized coordinates + 'human-bbox-abs': np.array([x1, y1, x2, y2, confidence], dtype=np.float16), # Absolute coordinates + 'body-keypoints': np.array([[x, y, confidence], ...], dtype=np.float16), # Shape: [133, 3], in the COCO-Wholebody format, normalized coordinates + 'body-keypoints-abs': np.array([[x, y, confidence], ...], dtype=np.float16), # Shape: [133, 3], in the COCO-Wholebody format, absolute coordinates + 'hand-keypoints': np.array([[x, y, confidence], ...], dtype=np.float16), # Shape: [42, 3], relative coordinates. It's a duplicate of the [91:133]-th keypoints of the 'body-keypoints' + 'face-bbox': np.array([x1, y1, width, height], dtype=np.float16), # Normalized coordinates of the face bounding boxes of the humans detected + 'face-bbox-abs': np.array([x1, y1, width, height], dtype=np.int16) # Absolute coordinates of the face bounding boxes of the humans detected + }, + # ... more humans in this frame + ], + # ... more frames +} +``` + +### Key Components: + +1. **Frame ID**: + - Key in the dictionary + - Should match the corresponding video frame + +2. **Per-Human Detection**: + - List of dictionaries, one per detected human + - Each detection contains: + - Bounding boxes (normalized and absolute) + - Body keypoints (133 points) + - Hand keypoints (42 points) + - Face bounding box + +3. **Coordinate Systems**: + - Normalized coordinates: Values between 0 and 1 + - Absolute coordinates: Pixel coordinates in the image + - All coordinates follow [x, y] format + +4. **Confidence Scores**: + - Included for each keypoint and bounding box + - Values between 0 and 1 + - Higher values indicate more reliable detections + +### Data Preparation Tips: + +1. **Keypoint Detection**: + - We used `rtmlib` for human keypoint detection and output the COCO-Wholebody keypoint convention. + +2. **File Organization**: + - Name the pickle file to match the video name + - Place in the `keypoint/` directory + - Ensure frame IDs match video frames + +## VisControl and EdgeControl +- These are self-supervised +- No separate data preparation needed +- Control inputs are generated on-the-fly during training. \ No newline at end of file diff --git a/examples/training_cosmos_transfer_7b.md b/examples/training_cosmos_transfer_7b.md new file mode 100644 index 00000000..66fd3ec6 --- /dev/null +++ b/examples/training_cosmos_transfer_7b.md @@ -0,0 +1,230 @@ +## Training Cosmos-Transfer1 Models +In this document, we provide examples and steps to: +- Build your own Cosmos-Transfer1 models, training from scratch; or +- Post-train Cosmos-Transfer1 models from our checkpoint using your data. + +The model is trained separately for each control input type. + + +### Model Support Matrix +We support the following Cosmos-Transfer models for pre-training and post-training. Review the available models and their compute requirements for post-training and inference to determine the best model for your use case. + +| Model Name | Model Status | Compute Requirements for Post-Training | +|------------------------------------------|--------------|----------------------------------------| +| Cosmos-Transfer1-7B [Depth] | **Supported**| 8 NVIDIA GPUs* | +| Cosmos-Transfer1-7B [Segmentation] | **Supported**| 8 NVIDIA GPUs* | +| Cosmos-Transfer1-7B [Edge] | **Supported**| 8 NVIDIA GPUs* | +| Cosmos-Transfer1-7B [Vis] | **Supported**| 8 NVIDIA GPUs* | +| Cosmos-Transfer1pt1-7B [Keypoint] | **Supported**| 8 NVIDIA GPUs* | + +**\*** `H100-80GB` or `A100-80GB` GPUs are recommended. + +### Environment setup + +Please refer to the Post-training section of [INSTALL.md](/INSTALL.md#post-training) for instructions on environment setup. + +### Download Checkpoints + +1. Generate a [Hugging Face](https://huggingface.co/settings/tokens) access token. Set the access token to 'Read' permission (default is 'Fine-grained'). + +2. Log in to Hugging Face with the access token: + +```bash +huggingface-cli login +``` + +3. Accept the [LlamaGuard-7b terms](https://huggingface.co/meta-llama/LlamaGuard-7b) + +4. Download the Cosmos model weights from [Hugging Face](https://huggingface.co/collections/nvidia/cosmos-transfer1-67c9d328196453be6e568d3e). Note that this will require about 300GB of free storage. + +```bash +CUDA_HOME=$CONDA_PREFIX PYTHONPATH=$(pwd) python scripts/download_checkpoints.py --output_dir checkpoints/ +``` + +5. The downloaded files should be in the following structure. + +``` +checkpoints/ +├── nvidia +│ ├── Cosmos-Transfer1-7B +│ │ ├── base_model.pt +│ │ ├── vis_control.pt +│ │ ├── edge_control.pt +│ │ ├── seg_control.pt +│ │ ├── depth_control.pt +│ │ ├── keypoint_control.pt +│ │ ├── 4kupscaler_control.pt +│ │ ├── config.json +│ │ └── guardrail +│ │ ├── aegis/ +│ │ ├── blocklist/ +│ │ ├── face_blur_filter/ +│ │ └── video_content_safety_filter/ +│ │ +│ ├── Cosmos-Transfer1-7B-Sample-AV/ +│ │ ├── base_model.pt +│ │ ├── hdmap_control.pt +│ │ └── lidar_control.pt +│ │ +│ │── Cosmos-Tokenize1-CV8x8x8-720p +│ │ ├── decoder.jit +│ │ ├── encoder.jit +│ │ ├── autoencoder.jit +│ │ └── mean_std.pt +│ │ +│ └── Cosmos-UpsamplePrompt1-12B-Transfer +│ ├── depth +│ │ ├── consolidated.safetensors +│ │ ├── params.json +│ │ └── tekken.json +│ ├── README.md +│ ├── segmentation +│ │ ├── consolidated.safetensors +│ │ ├── params.json +│ │ └── tekken.json +│ ├── seg_upsampler_example.png +│ └── viscontrol +│ ├── consolidated.safetensors +│ ├── params.json +│ └── tekken.json +│ +├── depth-anything/... +├── facebook/... +├── google-t5/... +└── IDEA-Research/ +``` + +Checkpoint Requirements: +- Base model (`base_model.pt`) and tokenizer models (under `Cosmos-Tokenize1-CV8x8x8-720p`): Required for all training. +- Control modality-specific model checkpoint (e.g., `seg_control.pt`): Only needed for post-training that specific control. Not needed if training from scratch. +- Other folders such as `depth-anything`, `facebook/sam2-hiera-large` etc.: optional. These are helper modules to process the video data into the respective control modalities such as depth and segmentation. + +### Example +There are 3 steps to train a Cosmos-Transfer1 model: preparing a dataset, prepare checkpoints, and launch training. + +In the example below, we use a subset of [HD-VILA-100M](https://github.com/microsoft/XPretrain/tree/main/hd-vila-100m) dataset to demonstrate the steps for preparing the data and launching training. After preprocessing, your dataset directory should be structured as follows: +``` +datasets/hdvila/ +├── metas/ +│ ├── *.json +│ ├── *.txt +├── videos/ +│ ├── *.mp4 +├── t5_xxl/ +│ ├── *.pickle +├── keypoint/ +│ ├── *.pickle +├── depth/ +│ ├── *.mp4 +├── seg/ +│ ├── *.pickle +└── / + ├── +``` + +File naming must be consistent across modalities. For example, to train a SegControl model with a video named `videos/example1.mp4`, the corresponding annotation files should be: `seg/example1.pickle`. + +Note: Only the folder corresponding to your chosen control input modality is required. For example, if you're training with depth as the control input, only the `depth/` subfolder is needed. + +#### 1. Prepare Videos and Captions + +The first step is to prepare a dataset with videos and captions. You must provide a folder containing a collection of videos in **MP4 format**, preferably 720p. These videos should focus on the subject throughout the entire video so that each video chunk contains the subject. + +Here we use a subset of sample videos from HD-VILA-100M as an example: + +```bash +# Download metadata with video urls and captions +mkdir -p datasets/hdvila +cd datasets/hdvila +wget https://huggingface.co/datasets/TempoFunk/hdvila-100M/resolve/main/hdvila-100M.jsonl +``` + +Run the following command to download the sample videos used for training: + +```bash +# Requirements for Youtube video downloads & video clipping +pip install pytubefix ffmpeg +``` + +```bash +# The script will downlaod the original HD-VILA-100M videos, save the corresponding clips, the captions and the metadata. +CUDA_HOME=$CONDA_PREFIX PYTHONPATH=$(pwd) python scripts/download_diffusion_example_data.py --dataset_path datasets/hdvila --N_videos 128 --do_download --do_clip +``` + +#### 2. Computing T5 Text Embeddings +Run the following command to pre-compute T5-XXL embeddings for the video captions used for training: + +```bash +# The script will read the captions, save the T5-XXL embeddings in pickle format. +CUDA_HOME=$CONDA_PREFIX PYTHONPATH=$(pwd) python scripts/get_t5_embeddings.py --dataset_path datasets/hdvila +``` + +#### 3. Obtaining the Control Input Data +Next, we generate the control input data corresponding to each video. If you already have accurate control input data (e.g., ground truth depth, segmentation masks, or human keypoints), you can skip this step -- just ensure your files are organized in the above structure, and follow the data format as detailed below. + +Here, as an example, we show show how to obtain the control input signals from the input RGB videos. Specifically: + +- DepthControl requires a depth video that is frame-wise aligned with the corresponding RGB video. This can be obtained by, for example, running [DepthAnythingV2](https://github.com/DepthAnything/Depth-Anything-V2) on the input videos. + +- SegControl requires a `.pickle` file in the SAM2 output format containing per-frame segmentation masks. See [Process Control Input Data](process_control_input_data_for_training.md) for detailed format requirements. + +- KeypointControl requires a `.pickle` file containing 2D human keypoint annotations for each frame. See [Process Control Input Data](process_control_input_data_for_training.md) for detailed format requirements. + +For VisControl and EdgeControl models: training is self-supervised. These models get control inputs (e.g., by applying blur or extracting Canny edges) from the input videos on-the-fly during training. Therefore, you do not need to prepare control input data separately for these modalities. + + + + +#### 4. Splitting the Checkpoints to TensorParallel Checkpoints +Due to the large model size, we leverage TensorParallel (TP) to split the model weights across multiple GPUs. We use 8 for the TP size. + +```bash +# Will split the Base model checkpoint into 8 TP checkpoints +python scripts/convert_ckpt_fsdp_to_tp.py checkpoints/nvidia/Cosmos-Transfer1-7B/base_model.pt +# Example: for VisControl checkpoint splitting for post-train. +python scripts/convert_ckpt_fsdp_to_tp.py checkpoints/nvidia/Cosmos-Transfer1-7B/vis_control.pt +``` +This will generate the TP checkpoints under `checkpoints/checkpoints_tp/*_mp_*.pt`, which we load in the training below. + +#### 5. Launch Training +Now we can start training! Run the following command to dry-run an example training job with the above data: +```bash +export OUTPUT_ROOT=checkpoints # default value + +torchrun --nproc_per_node=1 -m cosmos_transfer1.diffusion.training.train --dryrun --config=cosmos_transfer1/diffusion/config/config_train.py -- experiment=CTRL_7Bv1pt3_lvg_tp_121frames_control_input_edge_block3_pretrain +``` +Explanation of the command: + +- The trainer and the passed (master) config script will, in the background, load the detailed experiment configurations defined in `cosmos_transfer1/diffusion/config/training/experiment/ctrl_7b_tp_121frames.py`, and register the experiments configurations for all `hint_keys` (control modalities), covering both pretrain and post-train. We use [Hydra](https://hydra.cc/docs/intro/) for advanced configuration composition and overriding. + +- The `CTRL_7Bv1pt3_lvg_tp_121frames_control_input_edge_block3_pretrain` corresponds to an experiment name registered in `ctrl_7b_tp_121frames.py`. By specifiying this name, all the detailed config will be loaded. The full configuration is also written to `checkpoints/cosmos_transfer1_pretrain/CTRL_7Bv1_lvg/CTRL_7Bv1pt3_lvg_tp_121frames_control_input_edge_block3_pretrain/config.yaml`. + +- To customize your training, see `cosmos_transfer1/diffusion/config/training/experiment/ctrl_7b_tp_121frames.py` to understand how the detailed configs of the model, trainer, dataloader etc. are defined, and edit as needed. + +- Removing the `--dryrun` will start a real training job. + +- Change the `experiment` value will decide which control modality model is trained, and whether it's pretrain or post-train. For example, replacing the experiment name in the command with `CTRL_7Bv1pt3_lvg_tp_121frames_control_input_depth_block3_posttrain` will post-train the DepthControl model from the downloaded checkpoint instead. + +- The checkpoints will be saved to `${OUTPUT_ROOT}/PROJECT/GROUP/NAME`. See the job config to understand how they are determined: + +```python +# in cosmos_transfer1/diffusion/config/training/experiment/ctrl_7b_tp_121frames.py +config = LazyDict( + dict( + ... + job=dict( + project="cosmos_transfer1_pretrain", + group="CTRL_7Bv1_lvg", + name="CTRL_7Bv1pt3_lvg_tp_121frames_control_input_edge_block3_pretrain", + ), + ... + ) +) +``` + +During the training, the checkpoints will be saved in the below structure. +``` +checkpoints/cosmos_transfer1_pretrain/CTRL_7Bv1_lvg/CTRL_7Bv1pt3_lvg_tp_121frames_control_input_edge_block3_pretrain/checkpoints/ +├── iter_{NUMBER}_reg_model.pt +├── iter_{NUMBER}_ema_model.pt +``` diff --git a/scripts/convert_ckpt_fsdp_to_tp.py b/scripts/convert_ckpt_fsdp_to_tp.py index 50492cf3..9af82604 100644 --- a/scripts/convert_ckpt_fsdp_to_tp.py +++ b/scripts/convert_ckpt_fsdp_to_tp.py @@ -110,17 +110,25 @@ def convert_fsdp_to_tp(path_in: str, path_out: str) -> None: python convert_ckpt_fsdp_to_tp.py checkpoints/nvidia/Cosmos-Transfer1-7B/vis_control.pt This will save the Tensor Parallel (TP) checkpoints as 8 files in the same directory: - checkpoints/nvidia/Cosmos-Transfer1-7B/vis_control_tp_mp_0.pt + checkpoints/nvidia/Cosmos-Transfer1-7B/vis_control_mp_0.pt ... - checkpoints/nvidia/Cosmos-Transfer1-7B/vis_control_tp_mp_7.pt + checkpoints/nvidia/Cosmos-Transfer1-7B/vis_control_mp_7.pt ''' if len(sys.argv) != 2: print("Usage: python convert_ckpt_fsdp_to_tp.py ") print("Example: python convert_ckpt_fsdp_to_tp.py checkpoints/model.pt") sys.exit(1) + checkpoint_path = sys.argv[1] - out_tp_checkpoint_path = os.path.basename(checkpoint_path).replace(".pt", "") + + # Create checkpoints_tp directory in the same parent directory as the input checkpoint + input_dir = os.path.dirname(checkpoint_path) + tp_ckpt_dir = os.path.join(input_dir, 'checkpoints_tp') + os.makedirs(tp_ckpt_dir, exist_ok=True) + + # Use the same basename as input but in the checkpoints_tp directory + out_tp_checkpoint_path = os.path.join(tp_ckpt_dir, os.path.basename(checkpoint_path).replace(".pt", "")) try: convert_fsdp_to_tp(checkpoint_path, out_tp_checkpoint_path) print("Conversion completed successfully!") From d02c81f6843d2c1ed8ba45174e8d79257c02c475 Mon Sep 17 00:00:00 2001 From: Qianli Ma Date: Tue, 15 Apr 2025 17:22:47 -0700 Subject: [PATCH 08/10] fix: multiple minor fixes on example dataset --- .../diffusion/config/base/data.py | 4 +- .../experiment/ctrl_7b_tp_121frames.py | 9 +- .../diffusion/config/transfer/conditioner.py | 7 + .../diffusion/datasets/augmentor_provider.py | 18 ++- .../datasets/augmentors/control_input.py | 2 +- .../datasets/augmentors/merge_datadict.py | 2 +- .../diffusion/datasets/dataset_utils.py | 8 +- .../datasets/example_transfer_dataset.py | 120 +++++++++--------- cosmos_transfer1/diffusion/training/train.py | 2 +- cosmos_transfer1/utils/config.py | 2 + examples/training_cosmos_transfer_7b.md | 10 +- scripts/convert_ckpt_fsdp_to_tp.py | 14 +- 12 files changed, 107 insertions(+), 91 deletions(-) diff --git a/cosmos_transfer1/diffusion/config/base/data.py b/cosmos_transfer1/diffusion/config/base/data.py index 38e16300..d16fc0ef 100644 --- a/cosmos_transfer1/diffusion/config/base/data.py +++ b/cosmos_transfer1/diffusion/config/base/data.py @@ -21,7 +21,6 @@ def get_sampler(dataset): def get_example_transfer_dataset(hint_key, is_train=True): dataset = L(ExampleTransferDataset)( dataset_dir="datasets/hdvila", - chunk_size=256, num_frames=121, resolution="720", hint_key=hint_key, @@ -33,6 +32,9 @@ def get_example_transfer_dataset(hint_key, is_train=True): sampler=L(get_sampler)(dataset=dataset), batch_size=1, drop_last=True, + num_workers=8, # adjust as needed + prefetch_factor=2, # adjust as needed + pin_memory=True, ) diff --git a/cosmos_transfer1/diffusion/config/training/experiment/ctrl_7b_tp_121frames.py b/cosmos_transfer1/diffusion/config/training/experiment/ctrl_7b_tp_121frames.py index 356b71a0..2b65aa59 100644 --- a/cosmos_transfer1/diffusion/config/training/experiment/ctrl_7b_tp_121frames.py +++ b/cosmos_transfer1/diffusion/config/training/experiment/ctrl_7b_tp_121frames.py @@ -116,7 +116,7 @@ def make_ctrlnet_config_7b_training( 160, ], base_load_from=dict( - load_path=os.path.join(COSMOS_TRANSFER1_7B_CHECKPOINT, "checkpoints_tp", "base_model_mp_*.pt") + load_path=os.path.join("checkpoints", COSMOS_TRANSFER1_7B_CHECKPOINT, "checkpoints_tp", "base_model_model_mp_*.pt") ), # modify as needed. This is the TP version of base model ckpt (that's frozen during training). finetune_base_model=False, hint_mask=[True] * len(CTRL_HINT_KEYS_COMB[hint_key]), @@ -189,11 +189,10 @@ def make_ctrlnet_config_7b_training( # Register experiments for post-training from TP checkpoints. hint_key_short = key.replace("control_input_", "") # "control_input_vis" -> "vis" - base_ckpt_path = default_model_names[hint_key_short] - tp_ckpt_path = os.path.join(os.path.dirname(base_ckpt_path), "checkpoints_tp", os.path.basename(base_ckpt_path)) + pretrain_ckpt_path = default_model_names[hint_key_short] + # note: The TP ckpt path are specified as .pt to the script, but actually the _model_mp_*.pt files will be loaded. + tp_ckpt_path = os.path.join("checkpoints", os.path.dirname(pretrain_ckpt_path), "checkpoints_tp", os.path.basename(pretrain_ckpt_path)) config = make_ctrlnet_config_7b_training(hint_key=key, num_control_blocks=num_control_blocks, pretrain_model_path=tp_ckpt_path) - print(tp_ckpt_path, '=======\n\n') - import ipdb; ipdb.set_trace() cs.store( group="experiment", package="_global_", diff --git a/cosmos_transfer1/diffusion/config/transfer/conditioner.py b/cosmos_transfer1/diffusion/config/transfer/conditioner.py index 5c130f0d..0b4ce0ce 100644 --- a/cosmos_transfer1/diffusion/config/transfer/conditioner.py +++ b/cosmos_transfer1/diffusion/config/transfer/conditioner.py @@ -65,6 +65,13 @@ "control_input_upscale", ] +# for data loading. Defining corresponding sub-folders in the data folder +CTRL_AUG_KEYS = { + "depth": "depth", + "seg": "segmentation", + "keypoint": "keypoint", +} + BaseVideoConditionerWithCtrlConfig: LazyDict = L(VideoConditionerWithCtrl)( text=TextConfig(), diff --git a/cosmos_transfer1/diffusion/datasets/augmentor_provider.py b/cosmos_transfer1/diffusion/datasets/augmentor_provider.py index b50a6a8e..fdc81239 100644 --- a/cosmos_transfer1/diffusion/datasets/augmentor_provider.py +++ b/cosmos_transfer1/diffusion/datasets/augmentor_provider.py @@ -28,8 +28,8 @@ CTRL_HINT_KEYS, CTRL_HINT_KEYS_COMB, ) -from cosmos_transfer1.diffusion.datasets.example_transfer_dataset import CTRL_AUG_KEYS -from cosmos_transfer1.diffusion.config.transfer.blurs import BlurAugmentorConfig +from cosmos_transfer1.diffusion.config.transfer.conditioner import CTRL_AUG_KEYS +from cosmos_transfer1.diffusion.config.transfer.blurs import BlurAugmentorConfig, random_blur_config AUGMENTOR_OPTIONS = {} @@ -46,7 +46,6 @@ def decorator(func): @augmentor_register("video_basic_augmentor") def get_video_augmentor( resolution: str, - append_fps_frames: str = False, blur_config=None, ): return { @@ -56,7 +55,6 @@ def get_video_augmentor( "video", "fps", "num_frames", - "chunk_index", "frame_start", "frame_end", "orig_num_frames", @@ -83,7 +81,7 @@ def get_video_augmentor( def get_video_ctrlnet_augmentor(hint_key, use_random=True): def _get_video_ctrlnet_augmentor( resolution: str, - blur_config: BlurAugmentorConfig, + blur_config: BlurAugmentorConfig = random_blur_config, ): if hint_key == "control_input_keypoint": add_control_input = L(AddControlInputComb)( @@ -118,7 +116,6 @@ def _get_video_ctrlnet_augmentor( "video", "fps", "num_frames", - "chunk_index", "frame_start", "frame_end", "orig_num_frames", @@ -127,11 +124,12 @@ def _get_video_ctrlnet_augmentor( if key in hint_key: input_keys.append(value) output_keys.append(value) + augmentation = { - "merge_datadict": L(DataDictMerger)( - input_keys=input_keys, - output_keys=output_keys, - ), + # "merge_datadict": L(DataDictMerger)( + # input_keys=input_keys, + # output_keys=output_keys, + # ), # this addes the control input tensor to the data dict "add_control_input": add_control_input, # this resizes both the video and the control input to the model's required input size diff --git a/cosmos_transfer1/diffusion/datasets/augmentors/control_input.py b/cosmos_transfer1/diffusion/datasets/augmentors/control_input.py index 2dbdee89..4f279c49 100644 --- a/cosmos_transfer1/diffusion/datasets/augmentors/control_input.py +++ b/cosmos_transfer1/diffusion/datasets/augmentors/control_input.py @@ -20,7 +20,7 @@ import cv2 import matplotlib.colors as mcolors import numpy as np -import pycocotools +import pycocotools.mask import torch import torchvision.transforms.functional as transforms_F diff --git a/cosmos_transfer1/diffusion/datasets/augmentors/merge_datadict.py b/cosmos_transfer1/diffusion/datasets/augmentors/merge_datadict.py index 5703b6ac..8c811bbb 100644 --- a/cosmos_transfer1/diffusion/datasets/augmentors/merge_datadict.py +++ b/cosmos_transfer1/diffusion/datasets/augmentors/merge_datadict.py @@ -36,7 +36,7 @@ def __call__(self, data_dict: dict) -> dict: for key in self.input_keys: if key not in data_dict: log.warning( - f"DataDictMerger dataloader error: missing {key}, {data_dict['__url__']}, {data_dict['__key__']}", + f"DataDictMerger dataloader error: missing {key}; data_dict keys: {data_dict.keys()}", rank0_only=False, ) return None diff --git a/cosmos_transfer1/diffusion/datasets/dataset_utils.py b/cosmos_transfer1/diffusion/datasets/dataset_utils.py index 5bd2c1d1..78177165 100644 --- a/cosmos_transfer1/diffusion/datasets/dataset_utils.py +++ b/cosmos_transfer1/diffusion/datasets/dataset_utils.py @@ -53,12 +53,8 @@ def obtain_augmentation_size(data_dict: dict, augmentor_cfg: dict) -> Union[int, Returns: aug_size (int): Size of augmentation """ - if "__url__" in data_dict and "aspect_ratio" in data_dict["__url__"].meta.opts: - aspect_ratio = data_dict["__url__"].meta.opts["aspect_ratio"] - aug_size = augmentor_cfg["size"][aspect_ratio] - else: # Non-webdataset format - aspect_ratio = data_dict["aspect_ratio"] - aug_size = augmentor_cfg["size"][aspect_ratio] + aspect_ratio = data_dict["aspect_ratio"] + aug_size = augmentor_cfg["size"][aspect_ratio] return aug_size diff --git a/cosmos_transfer1/diffusion/datasets/example_transfer_dataset.py b/cosmos_transfer1/diffusion/datasets/example_transfer_dataset.py index 2c47ea53..4fb52ac7 100644 --- a/cosmos_transfer1/diffusion/datasets/example_transfer_dataset.py +++ b/cosmos_transfer1/diffusion/datasets/example_transfer_dataset.py @@ -32,13 +32,10 @@ from cosmos_transfer1.diffusion.datasets.augmentor_provider import AUGMENTOR_OPTIONS from cosmos_transfer1.diffusion.datasets.augmentors.control_input import VIDEO_RES_SIZE_INFO +from cosmos_transfer1.diffusion.inference.inference_utils import detect_aspect_ratio +from cosmos_transfer1.utils.lazy_config import instantiate -CTRL_AUG_KEYS = { - "depth": "depth", - "seg": "segmentation", - "keypoint": "keypoint", -} # mappings between control types and corresponding sub-folders names in the data folder CTRL_TYPE_INFO = { @@ -55,22 +52,17 @@ class ExampleTransferDataset(Dataset): def __init__( self, dataset_dir, - chunk_size, num_frames, resolution, - start_frame_interval=1, hint_key="control_input_vis", - # augmentor_name="video_basic_augmentor", is_train=True ): """Dataset class for loading video-text-to-video generation data with control inputs. Args: dataset_dir (str): Base path to the dataset directory - chunk_size (int): Interval between sampled frames in a sequence. - num_frames (int): Number of frames to load per sequence + num_frames (int): Number of consecutive frames to load per sequence resolution (str): resolution of the target video size - start_frame_interval (int): Interval for starting frames hint_key (str): The hint key for loading the correct control input data modality is_train (bool): Whether this is for training @@ -78,8 +70,6 @@ def __init__( """ super().__init__() self.dataset_dir = dataset_dir - self.start_frame_interval = start_frame_interval - self.chunk_size = chunk_size self.sequence_length = num_frames self.is_train = is_train self.resolution = resolution @@ -106,10 +96,10 @@ def __init__( augmentor_name = f"video_ctrlnet_augmentor_{hint_key}" # The augmentor will process the 'raw' control input data to the tensor, # add it to the data dict, and resize both the video and the control input to the model's required input size - self.augmentor = AUGMENTOR_OPTIONS[augmentor_name]( - resolution=resolution, - append_fps_frames=False - ) + # self.augmentor = AUGMENTOR_OPTIONS[augmentor_name](resolution=resolution) + augmentor_cfg = AUGMENTOR_OPTIONS[augmentor_name](resolution=resolution) + # Instantiate the augmentor configuration to get the actual augmentor objects (TODO (qianlim) remove instantiate here?) + self.augmentor = {k: instantiate(v) for k, v in augmentor_cfg.items()} def _init_samples(self, video_paths): samples = [] @@ -123,9 +113,13 @@ def _init_samples(self, video_paths): return samples def _load_and_process_video_path(self, video_path): + """ + Function for sampling a chunk of self.sequence_length frames from the video. + Current version: randomly sample a chunk of self.sequence_length frames from the video each time. + Modify this for a different sampling strategy if needed. + """ vr = VideoReader(video_path, ctx=cpu(0), num_threads=2) n_frames = len(vr) - # Check if all required control files exist ctrl_files_exist = True video_name = os.path.basename(video_path).replace(".mp4", "") @@ -147,36 +141,37 @@ def _load_and_process_video_path(self, video_path): if not ctrl_files_exist: return samples - for frame_i in range(0, n_frames, self.start_frame_interval): - sample = dict() - sample["video_path"] = video_path - sample["t5_embedding_path"] = os.path.join( + # Calculate the maximum possible starting frame that allows for a full sequence + max_start_idx = n_frames - self.sequence_length + + if max_start_idx < 0: # Video is too short + return samples + + # Randomly select a starting frame + start_frame = np.random.randint(0, max_start_idx + 1) + + sample = dict( + video_path=video_path, + t5_embedding_path=os.path.join( self.t5_dir, os.path.basename(video_path).replace(".mp4", ".pickle"), ) + ) - if self.ctrl_data_pth_config["folder"] is not None: - sample["ctrl_path"] = os.path.join( - self.dataset_dir, - self.ctrl_data_pth_config["folder"], - f"{video_name}.{self.ctrl_data_pth_config['format']}" - ) - else: - sample["ctrl_path"] = None - - sample["frame_ids"] = [] - sample["chunk_index"] = -1 - curr_frame_i = frame_i - while True: - if curr_frame_i > (n_frames - 1): - break - sample["frame_ids"].append(curr_frame_i) - if len(sample["frame_ids"]) == self.sequence_length: - break - curr_frame_i += self.chunk_size - if len(sample["frame_ids"]) == self.sequence_length: - sample["chunk_index"] += 1 - samples.append(sample) + if self.ctrl_data_pth_config["folder"] is not None: + sample["ctrl_path"] = os.path.join( + self.dataset_dir, + self.ctrl_data_pth_config["folder"], + f"{video_name}.{self.ctrl_data_pth_config['format']}" + ) + else: + sample["ctrl_path"] = None + + # Generate consecutive frame IDs + sample["frame_ids"] = list(range(start_frame, start_frame + self.sequence_length)) + # sample["chunk_index"] = 0 + samples.append(sample) + return samples def _load_control_data(self, sample): @@ -207,7 +202,7 @@ def _load_control_data(self, sample): "video": depth_frames, "frame_start": frame_ids[0], "frame_end": frame_ids[-1], - "chunk_index": sample["chunk_index"] + # "chunk_index": sample["chunk_index"] } except Exception as e: @@ -232,8 +227,6 @@ def _get_frames(self, video_path, frame_ids): frames, fps = self._load_video(video_path, frame_ids) frames = frames.astype(np.uint8) frames = torch.from_numpy(frames).permute(0, 3, 1, 2) # (l, c, h, w) - frames = self.preprocess(frames) - frames = torch.clamp(frames * 255.0, 0, 255).to(torch.uint8) return frames, fps def __getitem__(self, index): @@ -248,6 +241,9 @@ def __getitem__(self, index): video, fps = self._get_frames(video_path, frame_ids) video = video.permute(1, 0, 2, 3) # Rearrange from [T, C, H, W] to [C, T, H, W] + aspect_ratio = detect_aspect_ratio((video.shape[3], video.shape[2])) # the func expects (w, h) + data["aspect_ratio"] = aspect_ratio + # Basic data data["video"] = video data["video_name"] = { @@ -265,8 +261,8 @@ def __getitem__(self, index): # Add metadata data["fps"] = fps data["frame_start"] = frame_ids[0] - data["frame_end"] = frame_ids[-1] - data["chunk_index"] = sample["chunk_index"] + data["frame_end"] = frame_ids[-1] + 1 + # data["chunk_index"] = sample["chunk_index"] data["image_size"] = torch.tensor([704, 1280, 704, 1280]).cuda() data["num_frames"] = self.sequence_length data["padding_mask"] = torch.zeros(1, 704, 1280).cuda() @@ -302,28 +298,38 @@ def __str__(self): if __name__ == "__main__": + ''' + Sanity check for the dataset. + ''' + control_input_key = "control_input_edge" + visualize_control_input = False + dataset = ExampleTransferDataset( - dataset_dir="assets/hdvila/", - hint_key="control_input_seg", - chunk_size=1, + dataset_dir="datasets/hdvila/", + hint_key=control_input_key, num_frames=121, resolution="720", - # augmentor_name="video_basic_augmentor", is_train=True ) - - indices = [0, 13, 200, -1] + print("finished init dataset") + indices = [0, 12, 100, -1] for idx in indices: data = dataset[idx] print( ( f"{idx=} " + f"{data['frame_start']=}\n" + f"{data['frame_end']=}\n" f"{data['video'].sum()=}\n" f"{data['video'].shape=}\n" - f"{data['depth']['video'].sum()=}\n" - f"{data['depth']['video'].shape=}\n" + f"{data[control_input_key].shape}={data[control_input_key].shape}\n" # should match the video shape f"{data['video_name']=}\n" f"{data['t5_text_embeddings'].shape=}\n" "---" ) ) + if visualize_control_input: + import imageio + control_input_tensor = data[control_input_key].permute(1, 2, 3, 0).cpu().numpy() + video_name = "control_input_edge.mp4" + imageio.mimsave(video_name, control_input_tensor, fps=24) diff --git a/cosmos_transfer1/diffusion/training/train.py b/cosmos_transfer1/diffusion/training/train.py index ebf71be7..569675ca 100644 --- a/cosmos_transfer1/diffusion/training/train.py +++ b/cosmos_transfer1/diffusion/training/train.py @@ -59,7 +59,7 @@ def destroy_distributed(): def launch(config: Config, args: argparse.Namespace) -> None: # Check that the config is valid config.validate() - if config.trainer.timestamp_seed: # TODO (qianlim): check if this is set in the config yaml + if config.trainer.timestamp_seed: # Get the current time in microseconds current_time = int(time.time() * 1e6) # Combine the current time with worker_id to ensure different seeds across workers diff --git a/cosmos_transfer1/utils/config.py b/cosmos_transfer1/utils/config.py index e91b9bf7..8f257354 100644 --- a/cosmos_transfer1/utils/config.py +++ b/cosmos_transfer1/utils/config.py @@ -272,6 +272,8 @@ class TrainerConfig: memory_format: torch.memory_format = torch.preserve_format # Gradient accumulation (update step every N iteration). grad_accum_iter: int = 1 + # Whether to use the timestamp as the seed. Needed to ensure real randomness in loading data. + timestamp_seed: bool = True # # Profiling config # profiling: Profiling = attrs.field(factory=Profiling) diff --git a/examples/training_cosmos_transfer_7b.md b/examples/training_cosmos_transfer_7b.md index 66fd3ec6..b54c77f5 100644 --- a/examples/training_cosmos_transfer_7b.md +++ b/examples/training_cosmos_transfer_7b.md @@ -1,4 +1,4 @@ -## Training Cosmos-Transfer1 Models +# Training Cosmos-Transfer1 Models In this document, we provide examples and steps to: - Build your own Cosmos-Transfer1 models, training from scratch; or - Post-train Cosmos-Transfer1 models from our checkpoint using your data. @@ -160,7 +160,7 @@ CUDA_HOME=$CONDA_PREFIX PYTHONPATH=$(pwd) python scripts/get_t5_embeddings.py -- ``` #### 3. Obtaining the Control Input Data -Next, we generate the control input data corresponding to each video. If you already have accurate control input data (e.g., ground truth depth, segmentation masks, or human keypoints), you can skip this step -- just ensure your files are organized in the above structure, and follow the data format as detailed below. +Next, we generate the control input data corresponding to each video. If you already have accurate control input data (e.g., ground truth depth, segmentation masks, or human keypoints), you can skip this step -- just ensure your files are organized in the above structure, and follow the data format as detailed in [Process Control Input Data](process_control_input_data_for_training.md). Here, as an example, we show show how to obtain the control input signals from the input RGB videos. Specifically: @@ -180,9 +180,9 @@ Due to the large model size, we leverage TensorParallel (TP) to split the model ```bash # Will split the Base model checkpoint into 8 TP checkpoints -python scripts/convert_ckpt_fsdp_to_tp.py checkpoints/nvidia/Cosmos-Transfer1-7B/base_model.pt -# Example: for VisControl checkpoint splitting for post-train. -python scripts/convert_ckpt_fsdp_to_tp.py checkpoints/nvidia/Cosmos-Transfer1-7B/vis_control.pt +PYTHONPATH=. python scripts/convert_ckpt_fsdp_to_tp.py checkpoints/nvidia/Cosmos-Transfer1-7B/base_model.pt +# Example: for EdgeControl checkpoint splitting for post-train. +PYTHONPATH=. python scripts/convert_ckpt_fsdp_to_tp.py checkpoints/nvidia/Cosmos-Transfer1-7B/edge_control.pt ``` This will generate the TP checkpoints under `checkpoints/checkpoints_tp/*_mp_*.pt`, which we load in the training below. diff --git a/scripts/convert_ckpt_fsdp_to_tp.py b/scripts/convert_ckpt_fsdp_to_tp.py index 9af82604..901d0fa5 100644 --- a/scripts/convert_ckpt_fsdp_to_tp.py +++ b/scripts/convert_ckpt_fsdp_to_tp.py @@ -23,6 +23,7 @@ from collections import OrderedDict from typing import Dict, Any, List +from cosmos_transfer1.utils import log from cosmos_transfer1.utils.easy_io import easy_io @@ -54,7 +55,7 @@ def native_to_tp(reg_state_dict: Dict[str, Any], tp_size: int) -> List[OrderedDi A list of OrderedDicts, each representing a tensor parallel partition. """ tp_state_dict = [OrderedDict() for _ in range(tp_size)] - + log.info("Converting to TP checkpoint..") for key, value in reg_state_dict.items(): if key.endswith("_extra_state"): continue @@ -87,9 +88,11 @@ def convert_fsdp_to_tp(path_in: str, path_out: str) -> None: RuntimeError: For other conversion errors """ try: + log.info(f"Loading checkpoint from {path_in}..") native_ckpt = torch.load( path_in, map_location=torch.device("cpu"), + weights_only=False, # Load to CPU first; weights_only=False required for newer PyTorch versions ) state_dicts = native_to_tp(native_ckpt, TP_SIZE) except FileNotFoundError: @@ -97,9 +100,12 @@ def convert_fsdp_to_tp(path_in: str, path_out: str) -> None: except Exception as e: raise RuntimeError(f"Error loading checkpoint: {str(e)}") + log.info("Saving TP checkpoints..") + # Add a dummy grad_scaler and iteration to the checkpoint. Required by the training script. + easy_io.dump({'grad_scaler': {}, 'iteration': 0}, f"{path_out}.pt") for i in tqdm(range(TP_SIZE)): state_dict = {"model": state_dicts[i]} - easy_io.dump(state_dict, f"{path_out}_mp_{i}.pt") + easy_io.dump(state_dict, f"{path_out}_model_mp_{i}.pt") if __name__ == "__main__": @@ -110,9 +116,9 @@ def convert_fsdp_to_tp(path_in: str, path_out: str) -> None: python convert_ckpt_fsdp_to_tp.py checkpoints/nvidia/Cosmos-Transfer1-7B/vis_control.pt This will save the Tensor Parallel (TP) checkpoints as 8 files in the same directory: - checkpoints/nvidia/Cosmos-Transfer1-7B/vis_control_mp_0.pt + checkpoints/nvidia/Cosmos-Transfer1-7B/vis_control_model_mp_0.pt ... - checkpoints/nvidia/Cosmos-Transfer1-7B/vis_control_mp_7.pt + checkpoints/nvidia/Cosmos-Transfer1-7B/vis_control_model_mp_7.pt ''' if len(sys.argv) != 2: print("Usage: python convert_ckpt_fsdp_to_tp.py ") From 8d10477c91e2c89fdad1aed8baa0048ae00685f0 Mon Sep 17 00:00:00 2001 From: Qianli Ma Date: Tue, 15 Apr 2025 22:47:10 -0700 Subject: [PATCH 09/10] fix: multiple minor fixes + improve example dataset performance --- .../video_content_safety_filter/model.py | 2 +- .../checkpointer/ema_fsdp_checkpointer.py | 2 +- cosmos_transfer1/diffusion/config/base/net.py | 2 - .../experiment/ctrl_7b_tp_121frames.py | 6 +- .../config/training/registry_extra.py | 94 +- .../diffusion/config/training/tokenizer.py | 68 + .../diffusion/config/transfer/model.py | 2 +- .../datasets/example_transfer_dataset.py | 257 ++-- .../diffusion/modules/res_sampler.py | 2 +- .../diffusion/networks/general_dit.py | 594 ++------- .../networks/general_dit_ctrl_enc.py | 92 +- .../networks/general_dit_video_conditioned.py | 136 +- .../diffusion/training/models/model_image.py | 2 +- .../diffusion/training/modules/blocks.py | 1118 +++++++++++++++++ .../training/modules/pretrained_vae.py | 738 +++++++++++ .../diffusion/training/networks/__init__.py | 0 .../training/networks/general_dit.py | 1029 +++++++++++++++ .../training/networks/general_dit_ctrl_enc.py | 402 ++++++ .../networks/general_dit_video_conditioned.py | 259 ++++ cosmos_transfer1/utils/config.py | 86 +- cosmos_transfer1/utils/ddp_config.py | 106 ++ cosmos_transfer1/utils/distributed.py | 57 +- cosmos_transfer1/utils/log.py | 14 + cosmos_transfer1/utils/misc.py | 74 ++ scripts/convert_ckpt_fsdp_to_tp.py | 6 +- 25 files changed, 4179 insertions(+), 969 deletions(-) create mode 100644 cosmos_transfer1/diffusion/config/training/tokenizer.py create mode 100644 cosmos_transfer1/diffusion/training/modules/blocks.py create mode 100644 cosmos_transfer1/diffusion/training/modules/pretrained_vae.py create mode 100644 cosmos_transfer1/diffusion/training/networks/__init__.py create mode 100644 cosmos_transfer1/diffusion/training/networks/general_dit.py create mode 100644 cosmos_transfer1/diffusion/training/networks/general_dit_ctrl_enc.py create mode 100644 cosmos_transfer1/diffusion/training/networks/general_dit_video_conditioned.py create mode 100644 cosmos_transfer1/utils/ddp_config.py diff --git a/cosmos_transfer1/auxiliary/guardrail/video_content_safety_filter/model.py b/cosmos_transfer1/auxiliary/guardrail/video_content_safety_filter/model.py index a5260039..2c9cf962 100644 --- a/cosmos_transfer1/auxiliary/guardrail/video_content_safety_filter/model.py +++ b/cosmos_transfer1/auxiliary/guardrail/video_content_safety_filter/model.py @@ -17,7 +17,7 @@ import torch import torch.nn as nn -from cosmos_transfer1.utils.config import make_freezable +from cosmos_transfer1.utils.ddp_config import make_freezable @make_freezable diff --git a/cosmos_transfer1/checkpointer/ema_fsdp_checkpointer.py b/cosmos_transfer1/checkpointer/ema_fsdp_checkpointer.py index 4553ef5c..e74d8e8f 100644 --- a/cosmos_transfer1/checkpointer/ema_fsdp_checkpointer.py +++ b/cosmos_transfer1/checkpointer/ema_fsdp_checkpointer.py @@ -19,7 +19,7 @@ from cosmos_transfer1.utils import log from cosmos_transfer1.utils.config import CheckpointConfig as BaseCheckpointConfig -from cosmos_transfer1.utils.config import make_freezable +from cosmos_transfer1.utils.ddp_config import make_freezable from cosmos_transfer1.checkpointer.fsdp_checkpointer import FSDPCheckpointer as BaseFSDPCheckpointer diff --git a/cosmos_transfer1/diffusion/config/base/net.py b/cosmos_transfer1/diffusion/config/base/net.py index d7a6eab0..43f8e89b 100644 --- a/cosmos_transfer1/diffusion/config/base/net.py +++ b/cosmos_transfer1/diffusion/config/base/net.py @@ -34,9 +34,7 @@ pos_emb_learnable=False, pos_emb_interpolation="crop", block_x_format="THWBD", - additional_timestamp_channels=None, affline_emb_norm=True, use_adaln_lora=True, adaln_lora_dim=256, - legacy_patch_emb=False, ) \ No newline at end of file diff --git a/cosmos_transfer1/diffusion/config/training/experiment/ctrl_7b_tp_121frames.py b/cosmos_transfer1/diffusion/config/training/experiment/ctrl_7b_tp_121frames.py index 2b65aa59..9d264ed2 100644 --- a/cosmos_transfer1/diffusion/config/training/experiment/ctrl_7b_tp_121frames.py +++ b/cosmos_transfer1/diffusion/config/training/experiment/ctrl_7b_tp_121frames.py @@ -33,9 +33,9 @@ from cosmos_transfer1.utils.lazy_config import LazyDict from cosmos_transfer1.diffusion.config.transfer.conditioner import CTRL_HINT_KEYS_COMB from cosmos_transfer1.diffusion.training.models.model_ctrl import VideoDiffusionModelWithCtrl # this one has training support -from cosmos_transfer1.diffusion.networks.general_dit_video_conditioned import VideoExtendGeneralDIT +from cosmos_transfer1.diffusion.training.networks.general_dit_video_conditioned import VideoExtendGeneralDIT from cosmos_transfer1.diffusion.inference.inference_utils import default_model_names -from cosmos_transfer1.checkpoints import BASE_7B_CHECKPOINT_PATH, COSMOS_TRANSFER1_7B_CHECKPOINT +from cosmos_transfer1.checkpoints import COSMOS_TRANSFER1_7B_CHECKPOINT cs = ConfigStore.instance() @@ -116,7 +116,7 @@ def make_ctrlnet_config_7b_training( 160, ], base_load_from=dict( - load_path=os.path.join("checkpoints", COSMOS_TRANSFER1_7B_CHECKPOINT, "checkpoints_tp", "base_model_model_mp_*.pt") + load_path=os.path.join("checkpoints", COSMOS_TRANSFER1_7B_CHECKPOINT, "checkpoints_tp", "base_model_mp_*.pt") ), # modify as needed. This is the TP version of base model ckpt (that's frozen during training). finetune_base_model=False, hint_mask=[True] * len(CTRL_HINT_KEYS_COMB[hint_key]), diff --git a/cosmos_transfer1/diffusion/config/training/registry_extra.py b/cosmos_transfer1/diffusion/config/training/registry_extra.py index 1e83fb57..0ddc70a4 100644 --- a/cosmos_transfer1/diffusion/config/training/registry_extra.py +++ b/cosmos_transfer1/diffusion/config/training/registry_extra.py @@ -19,24 +19,100 @@ from hydra.core.config_store import ConfigStore -import cosmos_transfer1.diffusion.config.registry as base_registry +from cosmos_transfer1.diffusion.config.transfer.conditioner import CTRL_HINT_KEYS, BaseVideoConditionerWithCtrlConfig, VideoConditionerFpsSizePaddingWithCtrlConfig import cosmos_transfer1.diffusion.config.training.registry as base_training_registry -from cosmos_transfer1.diffusion.config.transfer.conditioner import CTRL_HINT_KEYS +from cosmos_transfer1.diffusion.config.registry import register_conditioner +from cosmos_transfer1.diffusion.config.base.data import register_data_ctrlnet +from cosmos_transfer1.diffusion.training.networks.general_dit_ctrl_enc import GeneralDITEncoder +from cosmos_transfer1.diffusion.training.networks.general_dit import GeneralDIT +from cosmos_transfer1.diffusion.config.training.tokenizer import get_cosmos_diffusion_tokenizer_comp8x8x8 +from cosmos_transfer1.diffusion.config.registry import register_tokenizer +from cosmos_transfer1.utils.lazy_config import LazyCall as L +from cosmos_transfer1.utils.lazy_config import LazyDict +import copy + +FADITV2ConfigTrain: LazyDict = L(GeneralDIT)( + max_img_h=240, + max_img_w=240, + max_frames=128, + in_channels=16, + out_channels=16, + patch_spatial=2, + patch_temporal=1, + model_channels=4096, + block_config="FA-CA-MLP", + num_blocks=28, + num_heads=32, + concat_padding_mask=True, + pos_emb_cls="rope3d", + pos_emb_learnable=False, + pos_emb_interpolation="crop", + block_x_format="THWBD", + additional_timestamp_channels=None, + affline_emb_norm=True, + use_adaln_lora=True, + adaln_lora_dim=256, + legacy_patch_emb=False, +) + +num_blocks = FADITV2ConfigTrain["num_blocks"] +FADITV2EncoderConfigTrain = copy.deepcopy(FADITV2ConfigTrain) +FADITV2EncoderConfigTrain["_target_"] = GeneralDITEncoder +FADITV2EncoderConfigTrain["layer_mask"] = [True if i > num_blocks // 2 else False for i in range(num_blocks)] -from cosmos_transfer1.diffusion.config.transfer.registry import register_experiment_ctrlnet -from cosmos_transfer1.diffusion.config.base.data import register_data_ctrlnet + +def register_net_train(cs): + cs.store( + group="net", + package="model.net", + name="faditv2_7b", + node=FADITV2ConfigTrain, + ) + cs.store(group="net_ctrl", package="model.net_ctrl", name="faditv2_7b", node=FADITV2EncoderConfigTrain) + + +def register_conditioner_ctrlnet(cs): + cs.store( + group="conditioner", + package="model.conditioner", + name="ctrlnet", + node=BaseVideoConditionerWithCtrlConfig, + ) + cs.store( + group="conditioner", + package="model.conditioner", + name="ctrlnet_add_fps_image_size_padding_mask", + node=VideoConditionerFpsSizePaddingWithCtrlConfig, + ) def register_configs(): cs = ConfigStore.instance() - # This will register all the basic configs: net, conditioner, tokenizer. - base_registry.register_configs() + # register all the basic configs: net, conditioner, tokenizer. + register_net_train(cs) + register_conditioner(cs) + register_conditioner_ctrlnet(cs) + register_tokenizer(cs) - # This will register training configs: optimizer, scheduler, callbacks, etc. + # register training configs: optimizer, scheduler, callbacks, etc. base_training_registry.register_configs() - # following will register data, experiment, callbacks + # register data, experiment, callbacks register_data_ctrlnet(cs) - register_experiment_ctrlnet(cs) + + # register hint keys + for hint_key in CTRL_HINT_KEYS: + cs.store( + group="hint_key", + package="model", + name=hint_key, + node=dict(hint_key=dict(hint_key=hint_key, grayscale=False)), + ) + cs.store( + group="hint_key", + package="model", + name=f"{hint_key}_grayscale", + node=dict(hint_key=dict(hint_key=hint_key, grayscale=True)), + ) diff --git a/cosmos_transfer1/diffusion/config/training/tokenizer.py b/cosmos_transfer1/diffusion/config/training/tokenizer.py new file mode 100644 index 00000000..c580d722 --- /dev/null +++ b/cosmos_transfer1/diffusion/config/training/tokenizer.py @@ -0,0 +1,68 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import omegaconf + +from cosmos_transfer1.diffusion.training.modules.pretrained_vae import ( + JITVAE, + JointImageVideoSharedJITTokenizer, + VideoJITTokenizer, +) +from cosmos_transfer1.utils.lazy_config import LazyCall as L + +TOKENIZER_OPTIONS = {} + + +def tokenizer_register(key): + def decorator(func): + TOKENIZER_OPTIONS[key] = func + return func + + return decorator + + +@tokenizer_register("cosmos_diffusion_tokenizer_comp8x8x8") +def get_cosmos_diffusion_tokenizer_comp8x8x8(resolution: str, chunk_duration: int) -> omegaconf.dictconfig.DictConfig: + assert resolution in ["720"] + + pixel_chunk_duration = chunk_duration + temporal_compression_factor = 8 + spatial_compression_factor = 8 + + return L(JointImageVideoSharedJITTokenizer)( + video_vae=L(VideoJITTokenizer)( + name="cosmos_1_0_diffusion_tokenizer", + enc_fp="checkpoints/nvidia/Cosmos-Tokenize1-CV8x8x8-720p/encoder.jit", + dec_fp="checkpoints/nvidia/Cosmos-Tokenize1-CV8x8x8-720p/decoder.jit", + mean_std_fp="checkpoints/nvidia/Cosmos-Tokenize1-CV8x8x8-720p/mean_std.pt", + latent_ch=16, + is_bf16=True, + pixel_chunk_duration=pixel_chunk_duration, + temporal_compression_factor=temporal_compression_factor, + spatial_compression_factor=spatial_compression_factor, + spatial_resolution=resolution, + ), + image_vae=L(JITVAE)( + name="cosmos_1_0_diffusion_tokenizer", + enc_fp="checkpoints/nvidia/Cosmos-Tokenize1-CV8x8x8-720p/encoder.jit", + dec_fp="checkpoints/nvidia/Cosmos-Tokenize1-CV8x8x8-720p/decoder.jit", + mean_std_fp="checkpoints/nvidia/Cosmos-Tokenize1-CV8x8x8-720p/mean_std.pt", + latent_ch=16, + is_image=False, + is_bf16=True, + ), + name="cosmos_1_0_diffusion_tokenizer", + latent_ch=16, + ) \ No newline at end of file diff --git a/cosmos_transfer1/diffusion/config/transfer/model.py b/cosmos_transfer1/diffusion/config/transfer/model.py index 0fd68c24..6b97449e 100644 --- a/cosmos_transfer1/diffusion/config/transfer/model.py +++ b/cosmos_transfer1/diffusion/config/transfer/model.py @@ -27,6 +27,6 @@ class CtrlModelConfig(DefaultModelConfig): finetune_base_model: bool = False hint_mask: list = [True] hint_dropout_rate: float = 0.0 - num_control_blocks: int = 5 + num_control_blocks: int = 3 random_drop_control_blocks: bool = False pixel_corruptor: LazyDict = None diff --git a/cosmos_transfer1/diffusion/datasets/example_transfer_dataset.py b/cosmos_transfer1/diffusion/datasets/example_transfer_dataset.py index 4fb52ac7..c856a5fe 100644 --- a/cosmos_transfer1/diffusion/datasets/example_transfer_dataset.py +++ b/cosmos_transfer1/diffusion/datasets/example_transfer_dataset.py @@ -21,12 +21,10 @@ import os import warnings import traceback -from concurrent.futures import ThreadPoolExecutor, as_completed import numpy as np import torch from torch.utils.data import Dataset -from tqdm import tqdm from decord import VideoReader, cpu import pickle @@ -79,100 +77,40 @@ def __init__( self.ctrl_type = hint_key.lstrip("control_input_") self.ctrl_data_pth_config = CTRL_TYPE_INFO[self.ctrl_type] - # Set up directories + # Set up directories - only collect paths video_dir = os.path.join(self.dataset_dir, "videos") self.video_paths = [os.path.join(video_dir, f) for f in os.listdir(video_dir) if f.endswith(".mp4")] self.t5_dir = os.path.join(self.dataset_dir, "t5_xxl") - print(f"{len(self.video_paths)} videos in total") - - # Initialize samples - self.samples = self._init_samples(self.video_paths) - self.samples = sorted(self.samples, key=lambda x: (x["video_path"], x["frame_ids"][0])) - print(f"{len(self.samples)} samples in total") + print(f"Finish initializing dataset with {len(self.video_paths)} videos in total.") # Set up preprocessing and augmentation - self.wrong_number = 0 - augmentor_name = f"video_ctrlnet_augmentor_{hint_key}" - # The augmentor will process the 'raw' control input data to the tensor, - # add it to the data dict, and resize both the video and the control input to the model's required input size - # self.augmentor = AUGMENTOR_OPTIONS[augmentor_name](resolution=resolution) augmentor_cfg = AUGMENTOR_OPTIONS[augmentor_name](resolution=resolution) - # Instantiate the augmentor configuration to get the actual augmentor objects (TODO (qianlim) remove instantiate here?) self.augmentor = {k: instantiate(v) for k, v in augmentor_cfg.items()} - def _init_samples(self, video_paths): - samples = [] - with ThreadPoolExecutor(32) as executor: - future_to_video_path = { - executor.submit(self._load_and_process_video_path, video_path): video_path - for video_path in video_paths - } - for future in tqdm(as_completed(future_to_video_path), total=len(video_paths)): - samples.extend(future.result()) - return samples - - def _load_and_process_video_path(self, video_path): - """ - Function for sampling a chunk of self.sequence_length frames from the video. - Current version: randomly sample a chunk of self.sequence_length frames from the video each time. - Modify this for a different sampling strategy if needed. - """ + def _sample_frames(self, video_path): + """Sample frames from video and get metadata""" vr = VideoReader(video_path, ctx=cpu(0), num_threads=2) n_frames = len(vr) - # Check if all required control files exist - ctrl_files_exist = True - video_name = os.path.basename(video_path).replace(".mp4", "") - - # load control input file if needed - if self.ctrl_data_pth_config["folder"] is not None: - ctrl_path = os.path.join( - self.dataset_dir, - self.ctrl_data_pth_config["folder"], - f"{video_name}.{self.ctrl_data_pth_config['format']}" - ) - if not os.path.exists(ctrl_path): - ctrl_files_exist = False - warnings.warn(f"Missing control input file: {ctrl_path}") - else: - ctrl_files_exist = True - - samples = [] - if not ctrl_files_exist: - return samples - - # Calculate the maximum possible starting frame that allows for a full sequence - max_start_idx = n_frames - self.sequence_length + # Calculate valid start frame range + max_start_idx = n_frames - self.sequence_length if max_start_idx < 0: # Video is too short - return samples - - # Randomly select a starting frame + return None, None, None + + # Sample start frame start_frame = np.random.randint(0, max_start_idx + 1) - - sample = dict( - video_path=video_path, - t5_embedding_path=os.path.join( - self.t5_dir, - os.path.basename(video_path).replace(".mp4", ".pickle"), - ) - ) - - if self.ctrl_data_pth_config["folder"] is not None: - sample["ctrl_path"] = os.path.join( - self.dataset_dir, - self.ctrl_data_pth_config["folder"], - f"{video_name}.{self.ctrl_data_pth_config['format']}" - ) - else: - sample["ctrl_path"] = None - - # Generate consecutive frame IDs - sample["frame_ids"] = list(range(start_frame, start_frame + self.sequence_length)) - # sample["chunk_index"] = 0 - samples.append(sample) + frame_ids = list(range(start_frame, start_frame + self.sequence_length)) - return samples + # Load frames + frames = vr.get_batch(frame_ids).asnumpy() + frames = frames.astype(np.uint8) + try: + fps = vr.get_avg_fps() + except Exception: # failed to read FPS + fps = 24 + + return frames, frame_ids, fps def _load_control_data(self, sample): """Load control data for the video clip.""" @@ -193,7 +131,7 @@ def _load_control_data(self, sample): vr = VideoReader(ctrl_path, ctx=cpu(0)) # Ensure the depth video has the same number of frames assert len(vr) >= frame_ids[-1] + 1, \ - f"Depth video {ctrl_data} has fewer frames than main video" + f"Depth video {ctrl_path} has fewer frames than main video" # Load the corresponding frames depth_frames = vr.get_batch(frame_ids).asnumpy() @@ -202,96 +140,89 @@ def _load_control_data(self, sample): "video": depth_frames, "frame_start": frame_ids[0], "frame_end": frame_ids[-1], - # "chunk_index": sample["chunk_index"] } except Exception as e: - warnings.warn(f"Failed to load control data from {ctrl_data}: {str(e)}") + warnings.warn(f"Failed to load control data from {ctrl_path}: {str(e)}") return None return data_dict - def _load_video(self, video_path, frame_ids): - vr = VideoReader(video_path, ctx=cpu(0), num_threads=2) - assert (np.array(frame_ids) < len(vr)).all() - assert (np.array(frame_ids) >= 0).all() - vr.seek(0) - frame_data = vr.get_batch(frame_ids).asnumpy() - try: - fps = vr.get_avg_fps() - except Exception: # failed to read FPS - fps = 24 - return frame_data, fps - - def _get_frames(self, video_path, frame_ids): - frames, fps = self._load_video(video_path, frame_ids) - frames = frames.astype(np.uint8) - frames = torch.from_numpy(frames).permute(0, 3, 1, 2) # (l, c, h, w) - return frames, fps - def __getitem__(self, index): - try: - sample = self.samples[index] - video_path = sample["video_path"] - frame_ids = sample["frame_ids"] - - data = dict() - - # Load video frames - video, fps = self._get_frames(video_path, frame_ids) - video = video.permute(1, 0, 2, 3) # Rearrange from [T, C, H, W] to [C, T, H, W] - - aspect_ratio = detect_aspect_ratio((video.shape[3], video.shape[2])) # the func expects (w, h) - data["aspect_ratio"] = aspect_ratio - - # Basic data - data["video"] = video - data["video_name"] = { - "video_path": video_path, - "t5_embedding_path": sample["t5_embedding_path"], - "start_frame_id": str(frame_ids[0]), - } - - # Load T5 embeddings - with open(sample["t5_embedding_path"], "rb") as f: - t5_embedding = pickle.load(f)[0] - data["t5_text_embeddings"] = torch.from_numpy(t5_embedding).cuda() - data["t5_text_mask"] = torch.ones(512, dtype=torch.int64).cuda() - - # Add metadata - data["fps"] = fps - data["frame_start"] = frame_ids[0] - data["frame_end"] = frame_ids[-1] + 1 - # data["chunk_index"] = sample["chunk_index"] - data["image_size"] = torch.tensor([704, 1280, 704, 1280]).cuda() - data["num_frames"] = self.sequence_length - data["padding_mask"] = torch.zeros(1, 704, 1280).cuda() - - if self.ctrl_type: - ctrl_data = self._load_control_data(sample) - if ctrl_data is None: # Control data loading failed, discard this sample and reload another sample - return self[np.random.randint(len(self.samples))] - data.update(ctrl_data) - - # Apply augmentations including control input processing - for aug_name, aug_fn in self.augmentor.items(): - data = aug_fn(data) - - return data - - except Exception: - warnings.warn( - f"Invalid data encountered: {self.samples[index]['video_path']}. Skipped " - f"(by randomly sampling another sample in the same dataset)." - ) - warnings.warn("FULL TRACEBACK:") - warnings.warn(traceback.format_exc()) - self.wrong_number += 1 - print(self.wrong_number) - return self[np.random.randint(len(self.samples))] + max_retries = 3 + for _ in range(max_retries): + try: + video_path = self.video_paths[index] + video_name = os.path.basename(video_path).replace(".mp4", "") + + # Sample frames + frames, frame_ids, fps = self._sample_frames(video_path) + if frames is None: # Invalid video or too short + index = np.random.randint(len(self.video_paths)) + continue + + data = dict() + + # Process video frames + video = torch.from_numpy(frames).permute(3, 0, 1, 2) # [T,H,W,C] -> [C,T,H,W] + aspect_ratio = detect_aspect_ratio((video.shape[3], video.shape[2])) # expects (w, h) + + # Basic data + data["video"] = video + data["aspect_ratio"] = aspect_ratio + data["video_name"] = { + "video_path": video_path, + "t5_embedding_path": os.path.join(self.t5_dir, f"{video_name}.pickle"), + "start_frame_id": str(frame_ids[0]), + } + + # Load T5 embeddings + with open(data["video_name"]["t5_embedding_path"], "rb") as f: + t5_embedding = pickle.load(f)[0] + data["t5_text_embeddings"] = torch.from_numpy(t5_embedding).cuda() + data["t5_text_mask"] = torch.ones(512, dtype=torch.int64).cuda() + + # Add metadata + data["fps"] = fps + data["frame_start"] = frame_ids[0] + data["frame_end"] = frame_ids[-1] + 1 + data["num_frames"] = self.sequence_length + data["image_size"] = torch.tensor([704, 1280, 704, 1280]).cuda() + data["padding_mask"] = torch.zeros(1, 704, 1280).cuda() + + if self.ctrl_type: + ctrl_data = self._load_control_data({ + "ctrl_path": os.path.join( + self.dataset_dir, + self.ctrl_data_pth_config["folder"], + f"{video_name}.{self.ctrl_data_pth_config['format']}" + ) if self.ctrl_data_pth_config["folder"] is not None else None, + "frame_ids": frame_ids + }) + if ctrl_data is None: # Control data loading failed + index = np.random.randint(len(self.video_paths)) + continue + data.update(ctrl_data) + + # Apply augmentations including control input processing + for aug_name, aug_fn in self.augmentor.items(): + data = aug_fn(data) + + return data + + except Exception: + warnings.warn( + f"Invalid data encountered: {self.video_paths[index]}. Skipped " + f"(by randomly sampling another sample in the same dataset)." + ) + warnings.warn("FULL TRACEBACK:") + warnings.warn(traceback.format_exc()) + if _ == max_retries - 1: + raise RuntimeError(f"Failed to load data after {max_retries} attempts") + index = np.random.randint(len(self.video_paths)) def __len__(self): - return len(self.samples) + return len(self.video_paths) def __str__(self): return f"{len(self.video_paths)} samples from {self.dataset_dir}" @@ -322,7 +253,7 @@ def __str__(self): f"{data['frame_end']=}\n" f"{data['video'].sum()=}\n" f"{data['video'].shape=}\n" - f"{data[control_input_key].shape}={data[control_input_key].shape}\n" # should match the video shape + f"{data[control_input_key].shape=}\n" # should match the video shape f"{data['video_name']=}\n" f"{data['t5_text_embeddings'].shape=}\n" "---" diff --git a/cosmos_transfer1/diffusion/diffusion/modules/res_sampler.py b/cosmos_transfer1/diffusion/diffusion/modules/res_sampler.py index 0b54a967..4e0d70fa 100644 --- a/cosmos_transfer1/diffusion/diffusion/modules/res_sampler.py +++ b/cosmos_transfer1/diffusion/diffusion/modules/res_sampler.py @@ -30,7 +30,7 @@ from cosmos_transfer1.diffusion.diffusion.functional.multi_step import get_multi_step_fn, is_multi_step_fn_supported from cosmos_transfer1.diffusion.diffusion.functional.runge_kutta import get_runge_kutta_fn, is_runge_kutta_fn_supported -from cosmos_transfer1.utils.config import make_freezable +from cosmos_transfer1.utils.ddp_config import make_freezable COMMON_SOLVER_OPTIONS = Literal["2ab", "2mid", "1euler"] diff --git a/cosmos_transfer1/diffusion/networks/general_dit.py b/cosmos_transfer1/diffusion/networks/general_dit.py index 06c00fa1..39d05229 100644 --- a/cosmos_transfer1/diffusion/networks/general_dit.py +++ b/cosmos_transfer1/diffusion/networks/general_dit.py @@ -15,30 +15,12 @@ """ A general implementation of adaln-modulated VIT-like~(DiT) transformer for video processing. -It allows us easy to switch building blocks used and their order. Its instantiation includes -* transformer on fully flattened tokens -* factored spatial and temporal attention -* factored non-overlap spatial and temporal attention -* mixing of above attention types - -Limitations: - -* In favor of simplicity and cleanness, many ops are not fused and we can do better -* such as combining mutiple adaln MLPs into one inside one transformer block. -* we use reshape heavily, which may be not efficient when its occurs unnecessary CUDA memory copy - -Purpose: -* A prototype for testing different attention types and their combinations -* Idealy, we want to know where we should allocate our resources / FLOPS / memory via extensive empirical studies """ - -from collections.abc import Container from typing import List, Optional, Tuple import torch from einops import rearrange -from megatron.core import parallel_state from torch import nn from torch.distributed import ProcessGroup, get_process_group_ranks from torchvision import transforms @@ -46,96 +28,57 @@ from cosmos_transfer1.diffusion.conditioner import DataType from cosmos_transfer1.diffusion.module.attention import get_normalization from cosmos_transfer1.diffusion.module.blocks import ( - DITBuildingBlock, FinalLayer, GeneralDITTransformerBlock, PatchEmbed, - SDXLTimestepEmbedding, - SDXLTimesteps, -) -from cosmos_transfer1.diffusion.module.position_embedding import ( - LearnableEmb3D, - LearnableEmb3D_FPS_Aware, - LearnablePosEmbAxis, - SinCosPosEmb, - SinCosPosEmb_FPS_Aware, - SinCosPosEmbAxis, - VideoRopePosition3DEmb, - VideoRopePositionEmb, + TimestepEmbedding, + Timesteps, ) -from cosmos_transfer1.diffusion.training.tensor_parallel import gather_along_first_dim, scatter_along_first_dim +from cosmos_transfer1.diffusion.module.position_embedding import LearnablePosEmbAxis, VideoRopePosition3DEmb from cosmos_transfer1.utils import log class GeneralDIT(nn.Module): """ A general implementation of adaln-modulated VIT-like~(DiT) transformer for video processing. - Attributes: + + Args: max_img_h (int): Maximum height of the input images. max_img_w (int): Maximum width of the input images. max_frames (int): Maximum number of frames in the video sequence. in_channels (int): Number of input channels (e.g., RGB channels for color images). out_channels (int): Number of output channels. - patch_spatial (tuple of int): Spatial resolution of patches for input processing. + patch_spatial (tuple): Spatial resolution of patches for input processing. patch_temporal (int): Temporal resolution of patches for input processing. concat_padding_mask (bool): If True, includes a mask channel in the input to handle padding. - block_config (str): Configuration of the transformer block, e.g., 'FA-CA-MLP', means - full attention, cross attention, and MLP in sequence in one transformer block. + block_config (str): Configuration of the transformer block. See Notes for supported block types. model_channels (int): Base number of channels used throughout the model. - num_blocks (int): Number of residual blocks per resolution in the transformer. - num_heads (int): Number of heads in the multi-head self-attention layers. - spatial_attn_win_size (int): Window size for the spatial attention mechanism. - temporal_attn_win_size (int): Window size for the temporal attention mechanism. - mlp_ratio (float): Expansion ratio for the MLP (multi-layer perceptron) blocks in the transformer. - use_memory_save (bool): If True, utilizes checkpointing to reduce memory usage during training. (Deprecated) - use_checkpoint (bool): If True, utilizes checkpointing to reduce memory usage during training for all blocks. - crossattn_emb_channels (int): Number of embedding channels used in the cross-attention layers. - use_cross_attn_mask (bool): If True, applies a mask during cross-attention operations to manage sequence alignment. - pos_emb_cls (str): Type of positional embeddings used ('sincos' for sinusoidal or other types). - pos_emb_learnable (bool): Specifies if positional embeddings are learnable. - pos_emb_interpolation (str): Method used for interpolating positional embeddings, e.g., 'crop' for cropping adjustments. - block_x_format (str, optional): The format of the input tensor for the transformer block. Defaults to "BTHWD". Only support 'BTHWD' and 'THWBD'. - legacy_patch_emb (bool): If True, applies 3D convolutional layers for video inputs, otherwise, use Linear! This is for backward compatibility. - rope_h_extrapolation_ratio (float): Ratio of the height extrapolation for the rope positional embedding. - rope_w_extrapolation_ratio (float): Ratio of the width extrapolation for the rope positional embedding. - rope_t_extrapolation_ratio (float): Ratio of the temporal extrapolation for the rope positional embedding. - Note: - block_config support block type: - * spatial_sa, ssa: spatial self attention - * temporal_sa, tsa: temporal self attention - * cross_attn, ca: cross attention - * full_attn: full attention on all flatten tokens - * mlp, ff: feed forward block - * use '-' to separate different building blocks, e.g., 'FA-CA-MLP' means full attention, cross attention, and MLP in sequence in one transformer block. - - Example: - >>> # full attention, cross attention, and MLP - >>> option1_block_config = 'FA-CA-MLP' - >>> model_1 = GeneralDIT( - max_img_h=64, max_img_w=64, max_frames=32, in_channels=16, out_channels=16, - patch_spatial=2, patch_temporal=1, model_channels=768, num_blocks=10, - num_heads=16, mlp_ratio=4.0, - spatial_attn_win_size=1, temporal_attn_win_size=1, - block_config=option1_block_config - ) - >>> option2_block_config = 'SSA-CA-MLP-TSA-CA-MLP' - >>> model_2 = GeneralDIT( - max_img_h=64, max_img_w=64, max_frames=32, in_channels=16, out_channels=16, - patch_spatial=2, patch_temporal=1, model_channels=768, num_blocks=10, - num_heads=16, mlp_ratio=4.0, - spatial_attn_win_size=1, temporal_attn_win_size=1, - block_config=option2_block_config - ) - >>> # option3 model - >>> model_3 = GeneralDIT( - max_img_h=64, max_img_w=64, max_frames=32, in_channels=16, out_channels=16, - patch_spatial=2, patch_temporal=1, model_channels=768, num_blocks=10, - num_heads=16, mlp_ratio=4.0, - spatial_attn_win_size=1, temporal_attn_win_size=2, - block_config=option2_block_config - ) - >>> # Process input tensor through the model - >>> output = model(input_tensor) + num_blocks (int): Number of transformer blocks. + num_heads (int): Number of heads in the multi-head attention layers. + mlp_ratio (float): Expansion ratio for MLP blocks. + block_x_format (str): Format of input tensor for transformer blocks ('BTHWD' or 'THWBD'). + crossattn_emb_channels (int): Number of embedding channels for cross-attention. + use_cross_attn_mask (bool): Whether to use mask in cross-attention. + pos_emb_cls (str): Type of positional embeddings. + pos_emb_learnable (bool): Whether positional embeddings are learnable. + pos_emb_interpolation (str): Method for interpolating positional embeddings. + affline_emb_norm (bool): Whether to normalize affine embeddings. + use_adaln_lora (bool): Whether to use AdaLN-LoRA. + adaln_lora_dim (int): Dimension for AdaLN-LoRA. + rope_h_extrapolation_ratio (float): Height extrapolation ratio for RoPE. + rope_w_extrapolation_ratio (float): Width extrapolation ratio for RoPE. + rope_t_extrapolation_ratio (float): Temporal extrapolation ratio for RoPE. + extra_per_block_abs_pos_emb (bool): Whether to use extra per-block absolute positional embeddings. + extra_per_block_abs_pos_emb_type (str): Type of extra per-block positional embeddings. + extra_h_extrapolation_ratio (float): Height extrapolation ratio for extra embeddings. + extra_w_extrapolation_ratio (float): Width extrapolation ratio for extra embeddings. + extra_t_extrapolation_ratio (float): Temporal extrapolation ratio for extra embeddings. + + Notes: + Supported block types in block_config: + * cross_attn, ca: Cross attention + * full_attn: Full attention on all flattened tokens + * mlp, ff: Feed forward block """ def __init__( @@ -153,13 +96,7 @@ def __init__( model_channels: int = 768, num_blocks: int = 10, num_heads: int = 16, - window_block_indexes: list = [], # index for window attention block - window_sizes: list = [], # window size for window attention block in the order of T, H, W - spatial_attn_win_size: int = 1, - temporal_attn_win_size: int = 1, mlp_ratio: float = 4.0, - use_memory_save: bool = False, - use_checkpoint: bool = False, block_x_format: str = "BTHWD", # cross attention settings crossattn_emb_channels: int = 1024, @@ -168,14 +105,9 @@ def __init__( pos_emb_cls: str = "sincos", pos_emb_learnable: bool = False, pos_emb_interpolation: str = "crop", - min_fps: int = 1, # 1 for getty video - max_fps: int = 30, # 120 for getty video but let's use 30 - additional_timestamp_channels: dict = None, # Follow SDXL, in format of {condition_name : dimension} affline_emb_norm: bool = False, # whether or not to normalize the affine embedding use_adaln_lora: bool = False, adaln_lora_dim: int = 256, - layer_mask: list = None, # whether or not a layer is used. For controlnet encoder - legacy_patch_emb: bool = True, rope_h_extrapolation_ratio: float = 1.0, rope_w_extrapolation_ratio: float = 1.0, rope_t_extrapolation_ratio: float = 1.0, @@ -184,6 +116,7 @@ def __init__( extra_h_extrapolation_ratio: float = 1.0, extra_w_extrapolation_ratio: float = 1.0, extra_t_extrapolation_ratio: float = 1.0, + layer_mask: list = None, # whether or not a layer is used. For controlnet encoder ) -> None: super().__init__() self.max_img_h = max_img_h @@ -202,11 +135,7 @@ def __init__( self.pos_emb_cls = pos_emb_cls self.pos_emb_learnable = pos_emb_learnable self.pos_emb_interpolation = pos_emb_interpolation - self.min_fps = min_fps - self.max_fps = max_fps - self.additional_timestamp_channels = additional_timestamp_channels self.affline_emb_norm = affline_emb_norm - self.legacy_patch_emb = legacy_patch_emb self.rope_h_extrapolation_ratio = rope_h_extrapolation_ratio self.rope_w_extrapolation_ratio = rope_w_extrapolation_ratio self.rope_t_extrapolation_ratio = rope_t_extrapolation_ratio @@ -219,23 +148,15 @@ def __init__( self.build_patch_embed() self.build_pos_embed() self.cp_group = None - self.sequence_parallel = getattr(parallel_state, "sequence_parallel", False) self.block_x_format = block_x_format self.use_adaln_lora = use_adaln_lora self.adaln_lora_dim = adaln_lora_dim self.t_embedder = nn.Sequential( - SDXLTimesteps(model_channels), - SDXLTimestepEmbedding(model_channels, model_channels, use_adaln_lora=use_adaln_lora), + Timesteps(model_channels), + TimestepEmbedding(model_channels, model_channels, use_adaln_lora=use_adaln_lora), ) self.blocks = nn.ModuleDict() - self.block_config = block_config - self.use_memory_save = use_memory_save - self.use_checkpoint = use_checkpoint - - assert ( - len(window_block_indexes) == 0 or block_config == "FA-CA-MLP" - ), "Block config must be FA-CA-MLP if using a combination of window attention and global attention" layer_mask = [False] * num_blocks if layer_mask is None else layer_mask assert ( @@ -249,33 +170,21 @@ def __init__( context_dim=crossattn_emb_channels, num_heads=num_heads, block_config=block_config, - window_sizes=( - window_sizes if idx in window_block_indexes else [] - ), # There will be bug if using "WA-CA-MLP" mlp_ratio=mlp_ratio, - spatial_attn_win_size=spatial_attn_win_size, - temporal_attn_win_size=temporal_attn_win_size, x_format=self.block_x_format, use_adaln_lora=use_adaln_lora, adaln_lora_dim=adaln_lora_dim, - use_checkpoint=use_checkpoint, ) self.build_decode_head() - self.build_additional_timestamp_embedder() if self.affline_emb_norm: - log.critical("Building affine embedding normalization layer") + log.debug("Building affine embedding normalization layer") self.affline_norm = get_normalization("R", model_channels) else: self.affline_norm = nn.Identity() - self.init_weights() - - if self.use_memory_save: - log.critical("Using checkpointing to save memory! only verified in 14B base model training!") - for block in self.blocks.values(): - block.set_memory_save() + self.initialize_weights() - def init_weights(self): + def initialize_weights(self): # Initialize transformer layers: def _basic_init(module): if isinstance(module, nn.Linear): @@ -300,50 +209,6 @@ def _basic_init(module): if block.adaLN_modulation[-1].bias is not None: nn.init.constant_(block.adaLN_modulation[-1].bias, 0) - # Tensor parallel - if parallel_state.is_initialized() and parallel_state.get_tensor_model_parallel_world_size() > 1: - self.initialize_tensor_parallel_weights() - - def initialize_tensor_parallel_weights(self): - """ - Initialize weights for tensor parallel layers. - - This function performs the following steps: - 1. Retrieves the tensor parallel rank. - 2. Saves the current random state. - 3. Sets a new random seed based on the tensor parallel rank. - 4. Initializes weights for attention and MLP layers in each block. - 5. Restores the original random state. - - The use of different random seeds for each rank ensures - unique initializations across parallel processes. - """ - tp_rank = parallel_state.get_tensor_model_parallel_rank() - - # Save the current random state - rng_state = torch.get_rng_state() - - # Set a new random seed based on the tensor parallel rank - torch.manual_seed(tp_rank) - - for block in self.blocks.values(): - for layer in block.blocks: - if layer.block_type in ["full_attn", "fa", "cross_attn", "ca"]: - # Initialize weights for attention layers - torch.nn.init.xavier_uniform_(layer.block.attn.to_q[0].weight) - torch.nn.init.xavier_uniform_(layer.block.attn.to_k[0].weight) - torch.nn.init.xavier_uniform_(layer.block.attn.to_v[0].weight) - torch.nn.init.xavier_uniform_(layer.block.attn.to_out[0].weight) - elif layer.block_type in ["mlp", "ff"]: - # Initialize weights for MLP layers - torch.nn.init.xavier_uniform_(layer.block.layer1.weight) - torch.nn.init.xavier_uniform_(layer.block.layer2.weight) - else: - raise ValueError(f"Unknown block type {layer.block_type}") - - # Restore the original random state - torch.set_rng_state(rng_state) - def build_decode_head(self): self.final_layer = FinalLayer( hidden_size=self.model_channels, @@ -375,60 +240,20 @@ def build_patch_embed(self): in_channels=in_channels, out_channels=model_channels, bias=False, - keep_spatio=True, - legacy_patch_emb=self.legacy_patch_emb, ) - # Initialize patch_embed like nn.Linear (instead of nn.Conv2d) - if self.legacy_patch_emb: - w = self.x_embedder.proj.weight.data - nn.init.xavier_uniform_(w.view([w.shape[0], -1])) - - def build_additional_timestamp_embedder(self): - if self.additional_timestamp_channels: - self.additional_timestamp_embedder = nn.ModuleDict() - for cond_name, cond_emb_channels in self.additional_timestamp_channels.items(): - log.critical( - f"Building additional timestamp embedder for {cond_name} with {cond_emb_channels} channels" - ) - self.additional_timestamp_embedder[cond_name] = nn.Sequential( - SDXLTimesteps(cond_emb_channels), - SDXLTimestepEmbedding(cond_emb_channels, cond_emb_channels), - ) - - def prepare_additional_timestamp_embedder(self, **kwargs): - condition_concat = [] - - for cond_name, embedder in self.additional_timestamp_embedder.items(): - condition_concat.append(embedder(kwargs[cond_name])[0]) - embedding = torch.cat(condition_concat, dim=1) - if embedding.shape[1] < self.model_channels: - embedding = nn.functional.pad(embedding, (0, self.model_channels - embedding.shape[1])) - return embedding def build_pos_embed(self): - if self.pos_emb_cls == "sincos": - cls_type = SinCosPosEmb - elif self.pos_emb_cls == "learnable": - cls_type = LearnableEmb3D - elif self.pos_emb_cls == "sincos_fps_aware": - cls_type = SinCosPosEmb_FPS_Aware - elif self.pos_emb_cls == "learnable_fps_aware": - cls_type = LearnableEmb3D_FPS_Aware - elif self.pos_emb_cls == "rope": - cls_type = VideoRopePositionEmb - elif self.pos_emb_cls == "rope3d": + if self.pos_emb_cls == "rope3d": cls_type = VideoRopePosition3DEmb else: raise ValueError(f"Unknown pos_emb_cls {self.pos_emb_cls}") - log.critical(f"Building positional embedding with {self.pos_emb_cls} class, impl {cls_type}") + log.debug(f"Building positional embedding with {self.pos_emb_cls} class, impl {cls_type}") kwargs = dict( model_channels=self.model_channels, len_h=self.max_img_h // self.patch_spatial, len_w=self.max_img_w // self.patch_spatial, len_t=self.max_frames // self.patch_temporal, - max_fps=self.max_fps, - min_fps=self.min_fps, is_learnable=self.pos_emb_learnable, interpolation=self.pos_emb_interpolation, head_dim=self.model_channels // self.num_heads, @@ -442,20 +267,14 @@ def build_pos_embed(self): if self.extra_per_block_abs_pos_emb: assert self.extra_per_block_abs_pos_emb_type in [ - "sincos", "learnable", ], f"Unknown extra_per_block_abs_pos_emb_type {self.extra_per_block_abs_pos_emb_type}" kwargs["h_extrapolation_ratio"] = self.extra_h_extrapolation_ratio kwargs["w_extrapolation_ratio"] = self.extra_w_extrapolation_ratio kwargs["t_extrapolation_ratio"] = self.extra_t_extrapolation_ratio - if self.extra_per_block_abs_pos_emb_type == "sincos": - self.extra_pos_embedder = SinCosPosEmbAxis( - **kwargs, - ) - elif self.extra_per_block_abs_pos_emb_type == "learnable": - self.extra_pos_embedder = LearnablePosEmbAxis( - **kwargs, - ) + self.extra_pos_embedder = LearnablePosEmbAxis( + **kwargs, + ) def prepare_embedded_sequence( self, @@ -485,8 +304,8 @@ def prepare_embedded_sequence( - The method of applying positional embeddings depends on the value of `self.pos_emb_cls`. - If 'rope' is in `self.pos_emb_cls` (case insensitive), the positional embeddings are generated using the `self.pos_embedder` with the shape [T, H, W]. - - If "fps_aware" is in `self.pos_emb_cls`, the positional embeddings are generated using the `self.pos_embedder` - with the fps tensor. + - If "fps_aware" is in `self.pos_emb_cls`, the positional embeddings are generated using the + `self.pos_embedder` with the fps tensor. - Otherwise, the positional embeddings are generated without considering fps. """ if self.concat_padding_mask: @@ -510,6 +329,7 @@ def prepare_embedded_sequence( x_B_T_H_W_D = x_B_T_H_W_D + self.pos_embedder(x_B_T_H_W_D, fps=fps) # [B, T, H, W, D] else: x_B_T_H_W_D = x_B_T_H_W_D + self.pos_embedder(x_B_T_H_W_D) # [B, T, H, W, D] + return x_B_T_H_W_D, None, extra_pos_emb def decoder_head( @@ -587,30 +407,10 @@ def forward_before_blocks( if scalar_feature is not None: raise NotImplementedError("Scalar feature is not implemented yet.") - timesteps_B_D = timesteps_B_D + scalar_feature.mean(dim=1) - - if self.additional_timestamp_channels: - additional_cond_B_D = self.prepare_additional_timestamp_embedder( - bs=x.shape[0], - fps=fps, - h=image_size[:, 0], - w=image_size[:, 1], - org_h=image_size[:, 2], - org_w=image_size[:, 3], - ) - - affline_emb_B_D += additional_cond_B_D - affline_scale_log_info["additional_cond_B_D"] = additional_cond_B_D.detach() affline_scale_log_info["affline_emb_B_D"] = affline_emb_B_D.detach() affline_emb_B_D = self.affline_norm(affline_emb_B_D) - # for logging purpose - self.affline_scale_log_info = affline_scale_log_info - self.affline_emb = affline_emb_B_D - self.crossattn_emb = crossattn_emb - self.crossattn_mask = crossattn_mask - if self.use_cross_attn_mask: crossattn_mask = crossattn_mask[:, None, None, :].to(dtype=torch.bool) # [B, 1, 1, length] else: @@ -627,24 +427,6 @@ def forward_before_blocks( if crossattn_mask: crossattn_mask = rearrange(crossattn_mask, "B M -> M B") - if self.sequence_parallel: - tp_group = parallel_state.get_tensor_model_parallel_group() - # Sequence parallel requires the input tensor to be scattered along the first dimension. - assert self.block_config == "FA-CA-MLP" # Only support this block config for now - T, H, W, B, D = x.shape - # variable name x_T_H_W_B_D is no longer valid. x is reshaped to THW*1*1*b*D and will be reshaped back in FinalLayer - x = x.view(T * H * W, 1, 1, B, D) - assert x.shape[0] % parallel_state.get_tensor_model_parallel_world_size() == 0 - x = scatter_along_first_dim(x, tp_group) - - if extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D is not None: - extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D = extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D.view( - T * H * W, 1, 1, B, D - ) - extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D = scatter_along_first_dim( - extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D, tp_group - ) - elif self.blocks["block0"].x_format == "BTHWD": x = x_B_T_H_W_D else: @@ -661,199 +443,6 @@ def forward_before_blocks( } return output - def forward_blocks_regular( - self, - x, - affline_emb_B_D, - crossattn_emb, - crossattn_mask, - rope_emb_L_1_1_D, - adaln_lora_B_3D, - extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D, - feature_indices, - original_shape, - x_ctrl, - return_features_early, - ): - features = [] - for name, block in self.blocks.items(): - assert ( - self.blocks["block0"].x_format == block.x_format - ), f"First block has x_format {self.blocks[0].x_format}, got {block.x_format}" - x = block( - x, - affline_emb_B_D, - crossattn_emb, - crossattn_mask, - rope_emb_L_1_1_D=rope_emb_L_1_1_D, - adaln_lora_B_3D=adaln_lora_B_3D, - extra_per_block_pos_emb=extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D, - ) - - # Extract features - block_idx = int(name.split("block")[-1]) - if block_idx in feature_indices: - B, C, T, H, W = original_shape - H = H // self.patch_spatial - W = W // self.patch_spatial - T = T // self.patch_temporal - if self.sequence_parallel: - x_feat = gather_along_first_dim(x, parallel_state.get_tensor_model_parallel_group()) - x_B_T_H_W_D = rearrange(x_feat, "(T H W) 1 1 B D -> B T H W D", T=T, H=H, W=W) - else: - x_feat = x - if self.blocks["block0"].x_format == "THWBD": - x_B_T_H_W_D = rearrange(x_feat, "T H W B D -> B T H W D", T=T, H=H, W=W) - elif self.blocks["block0"].x_format == "BTHWD": - x_B_T_H_W_D = x_feat - else: - raise ValueError(f"Unknown x_format {self.blocks[-1].x_format}") - - features.append(x_B_T_H_W_D) - - if x_ctrl is not None and name in x_ctrl: - x = x + x_ctrl[name] - # If we have all of the features, we can exit early - if return_features_early and len(features) == len(feature_indices): - return features - - if self.blocks["block0"].x_format == "THWBD": - x_B_T_H_W_D = rearrange(x, "T H W B D -> B T H W D") - elif self.blocks["block0"].x_format == "BTHWD": - x_B_T_H_W_D = x - else: - raise ValueError(f"Unknown x_format {self.blocks[-1].x_format}") - - x_B_D_T_H_W = self.decoder_head( - x_B_T_H_W_D=x_B_T_H_W_D, - emb_B_D=affline_emb_B_D, - crossattn_emb=None, - origin_shape=original_shape, - crossattn_mask=None, - adaln_lora_B_3D=adaln_lora_B_3D, - ) - - if len(feature_indices) == 0: - # no features requested, return only the model output - return x_B_D_T_H_W - else: - # score and features; score, features - return x_B_D_T_H_W, features - - def forward_blocks_memory_save( - self, - x, - affline_emb_B_D, - crossattn_emb, - crossattn_mask, - rope_emb_L_1_1_D, - adaln_lora_B_3D, - extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D, - feature_indices, - original_shape, - x_ctrl, - return_features_early, - ): - x_before_gate = 0 - x_skip = rearrange(x, "T H W B D -> (T H W) B D") - assert self.blocks["block0"].x_format == "THWBD" - if extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D is not None: - extra_per_block_pos_emb = rearrange(extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D, "T H W B D -> (T H W) B D") - else: - extra_per_block_pos_emb = None - gate_L_B_D = 1.0 - - features = [] - for name, block in self.blocks.items(): - gate_L_B_D, x_before_gate, x_skip = block( - x_before_gate, - x_skip, - gate_L_B_D, - affline_emb_B_D, - crossattn_emb, - crossattn_mask, - rope_emb_L_1_1_D=rope_emb_L_1_1_D, - adaln_lora_B_3D=adaln_lora_B_3D, - extra_per_block_pos_emb=extra_per_block_pos_emb, - ) - - # Extract features. - # Convert the block index in the memory save mode to the block index in the regular mode. - block_idx = int(name.split("block")[-1]) - 1 - if block_idx in feature_indices: - B, C, T_before_patchify, H_before_patchify, W_before_patchify = original_shape - H = H_before_patchify // self.patch_spatial - W = W_before_patchify // self.patch_spatial - T = T_before_patchify // self.patch_temporal - if self.sequence_parallel: - x_feat = gather_along_first_dim(x_skip, parallel_state.get_tensor_model_parallel_group()) - x_B_T_H_W_D = rearrange(x_feat, "(T H W) 1 1 B D -> B T H W D", T=T, H=H, W=W) - else: - x_feat = x_skip - x_B_T_H_W_D = rearrange(x_feat, "(T H W) B D -> B T H W D", T=T, H=H, W=W) - - features.append(x_B_T_H_W_D) - - new_name = f"block{block_idx}" - if x_ctrl is not None and new_name in x_ctrl: - x_ctrl_ = x_ctrl[new_name] - x_ctrl_ = rearrange(x_ctrl_, "T H W B D -> (T H W) B D") - x_skip = x_skip + x_ctrl_ - # If we have all of the features, we can exit early - if return_features_early and len(features) == len(feature_indices): - return features - - x_THW_B_D_before_gate = x_before_gate - x_THW_B_D_skip = x_skip - - B, C, T_before_patchify, H_before_patchify, W_before_patchify = original_shape - x_BT_HW_D_before_gate = rearrange( - x_THW_B_D_before_gate, - "(T H W) B D -> (B T) (H W) D", - T=T_before_patchify // self.patch_temporal, - H=H_before_patchify // self.patch_spatial, - W=W_before_patchify // self.patch_spatial, - ) - x_BT_HW_D_skip = rearrange( - x_THW_B_D_skip, - "(T H W) B D -> (B T) (H W) D", - T=T_before_patchify // self.patch_temporal, - H=H_before_patchify // self.patch_spatial, - W=W_before_patchify // self.patch_spatial, - ) - - x_BT_HW_D = self.final_layer.forward_with_memory_save( - x_BT_HW_D_before_gate=x_BT_HW_D_before_gate, - x_BT_HW_D_skip=x_BT_HW_D_skip, - gate_L_B_D=gate_L_B_D, - emb_B_D=affline_emb_B_D, - adaln_lora_B_3D=adaln_lora_B_3D, - ) - - # This is to ensure x_BT_HW_D has the correct shape because - # when we merge T, H, W into one dimension, x_BT_HW_D has shape (B * T * H * W, 1*1, D). - x_BT_HW_D = x_BT_HW_D.view( - B * T_before_patchify // self.patch_temporal, - H_before_patchify // self.patch_spatial * W_before_patchify // self.patch_spatial, - -1, - ) - x_B_D_T_H_W = rearrange( - x_BT_HW_D, - "(B T) (H W) (p1 p2 t C) -> B C (T t) (H p1) (W p2)", - p1=self.patch_spatial, - p2=self.patch_spatial, - H=H_before_patchify // self.patch_spatial, - W=W_before_patchify // self.patch_spatial, - t=self.patch_temporal, - B=B, - ) - if len(feature_indices) == 0: - # no features requested, return only the model output - return x_B_D_T_H_W - else: - # score and features; score, features - return x_B_D_T_H_W, features - def forward( self, x: torch.Tensor, @@ -861,16 +450,13 @@ def forward( crossattn_emb: torch.Tensor, crossattn_mask: Optional[torch.Tensor] = None, fps: Optional[torch.Tensor] = None, - image_size: Optional[torch.Tensor] = None, padding_mask: Optional[torch.Tensor] = None, scalar_feature: Optional[torch.Tensor] = None, data_type: Optional[DataType] = DataType.VIDEO, - x_ctrl: Optional[dict] = None, latent_condition: Optional[torch.Tensor] = None, latent_condition_sigma: Optional[torch.Tensor] = None, - feature_indices: Optional[Container[int]] = None, - return_features_early: bool = False, condition_video_augment_sigma: Optional[torch.Tensor] = None, + x_ctrl: Optional[dict] = None, **kwargs, ) -> torch.Tensor | List[torch.Tensor] | Tuple[torch.Tensor, List[torch.Tensor]]: """ @@ -879,19 +465,10 @@ def forward( timesteps: (B, ) tensor of timesteps crossattn_emb: (B, N, D) tensor of cross-attention embeddings crossattn_mask: (B, N) tensor of cross-attention masks - feature_indices: A set of feature indices (a set of integers) decides which blocks - to extract features from. If the set is non-empty, then features will be returned. - By default, feature_indices=None means extract no features. - return_features_early: If true, the forward pass returns the features once the set is complete. - This means the forward pass will not finish completely and no final output is returned. - condition_video_augment_sigma: (B,) used in lvg(long video generation), we add noise with this sigma to augment condition input, the lvg model will condition on the condition_video_augment_sigma value; + condition_video_augment_sigma: (B,) used in lvg(long video generation), we add noise with this sigma to + augment condition input, the lvg model will condition on the condition_video_augment_sigma value; we need forward_before_blocks pass to the forward_before_blocks function. """ - if feature_indices is None: - feature_indices = {} - if return_features_early and len(feature_indices) == 0: - # Exit immediately if user requested this. - return [] inputs = self.forward_before_blocks( x=x, @@ -899,7 +476,6 @@ def forward( crossattn_emb=crossattn_emb, crossattn_mask=crossattn_mask, fps=fps, - image_size=image_size, padding_mask=padding_mask, scalar_feature=scalar_feature, data_type=data_type, @@ -923,38 +499,35 @@ def forward( x.shape == extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D.shape ), f"{x.shape} != {extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D.shape} {original_shape}" - if self.use_memory_save: - return self.forward_blocks_memory_save( + for name, block in self.blocks.items(): + assert ( + self.blocks["block0"].x_format == block.x_format + ), f"First block has x_format {self.blocks[0].x_format}, got {block.x_format}" + + x = block( x, affline_emb_B_D, crossattn_emb, crossattn_mask, - rope_emb_L_1_1_D, - adaln_lora_B_3D, - extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D, - feature_indices, - original_shape, - x_ctrl, - return_features_early, + rope_emb_L_1_1_D=rope_emb_L_1_1_D, + adaln_lora_B_3D=adaln_lora_B_3D, + extra_per_block_pos_emb=extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D, ) + if x_ctrl is not None and name in x_ctrl: + x = x + x_ctrl[name] - return self.forward_blocks_regular( - x, - affline_emb_B_D, - crossattn_emb, - crossattn_mask, - rope_emb_L_1_1_D, - adaln_lora_B_3D, - extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D, - feature_indices, - original_shape, - x_ctrl, - return_features_early, + x_B_T_H_W_D = rearrange(x, "T H W B D -> B T H W D") + + x_B_D_T_H_W = self.decoder_head( + x_B_T_H_W_D=x_B_T_H_W_D, + emb_B_D=affline_emb_B_D, + crossattn_emb=None, + origin_shape=original_shape, + crossattn_mask=None, + adaln_lora_B_3D=adaln_lora_B_3D, ) - @property - def fsdp_wrap_block_cls(self): - return DITBuildingBlock + return x_B_D_T_H_W def enable_context_parallel(self, cp_group: ProcessGroup): cp_ranks = get_process_group_ranks(cp_group) @@ -1001,29 +574,6 @@ def disable_context_parallel(self): log.debug("[CP] Disable context parallelism.") - def enable_sequence_parallel(self): - self._set_sequence_parallel(True) - - def disable_sequence_parallel(self): - self._set_sequence_parallel(False) - - def _set_sequence_parallel(self, status: bool): - self.sequence_parallel = status - self.final_layer.sequence_parallel = status - for block in self.blocks.values(): - for layer in block.blocks: - if layer.block_type in ["full_attn", "fa", "cross_attn", "ca"]: - layer.block.attn.to_q[0].sequence_parallel = status - layer.block.attn.to_k[0].sequence_parallel = status - layer.block.attn.to_v[0].sequence_parallel = status - layer.block.attn.to_out[0].sequence_parallel = status - layer.block.attn.attn_op.sequence_parallel = status - elif layer.block_type in ["mlp", "ff"]: - layer.block.layer1.sequence_parallel = status - layer.block.layer2.sequence_parallel = status - else: - raise ValueError(f"Unknown block type {layer.block_type}") - @property def is_context_parallel_enabled(self): - return self.cp_group is not None + return self.cp_group is not None \ No newline at end of file diff --git a/cosmos_transfer1/diffusion/networks/general_dit_ctrl_enc.py b/cosmos_transfer1/diffusion/networks/general_dit_ctrl_enc.py index c66a4488..b6b9e04a 100644 --- a/cosmos_transfer1/diffusion/networks/general_dit_ctrl_enc.py +++ b/cosmos_transfer1/diffusion/networks/general_dit_ctrl_enc.py @@ -19,11 +19,10 @@ from typing import List, Optional, Tuple -import numpy as np import torch from einops import rearrange -from megatron.core import parallel_state +# from megatron.core import parallel_state from torch import nn from torchvision import transforms @@ -31,7 +30,6 @@ from cosmos_transfer1.diffusion.module.blocks import PatchEmbed, zero_module from cosmos_transfer1.diffusion.module.parallel import split_inputs_cp from cosmos_transfer1.diffusion.networks.general_dit_video_conditioned import VideoExtendGeneralDIT as GeneralDIT -from cosmos_transfer1.diffusion.training.tensor_parallel import scatter_along_first_dim class GeneralDITEncoder(GeneralDIT): @@ -62,7 +60,7 @@ def __init__(self, *args, **kwargs): input_hint_block += [nn.Linear(hint_nf[i], hint_nf[i + 1]), nonlinearity] self.input_hint_block = nn.Sequential(*input_hint_block) # Initialize weights - self.init_weights() + self.initialize_weights() self.zero_blocks = nn.ModuleDict() for idx in range(num_blocks): if layer_mask[idx]: @@ -70,11 +68,6 @@ def __init__(self, *args, **kwargs): self.zero_blocks[f"block{idx}"] = zero_module(nn.Linear(model_channels, model_channels)) self.input_hint_block.append(zero_module(nn.Linear(hint_nf[-1], model_channels))) - def _set_sequence_parallel(self, status: bool): - self.zero_blocks.sequence_parallel = status - self.input_hint_block.sequence_parallel = status - super()._set_sequence_parallel(status) - def build_hint_patch_embed(self): concat_padding_mask, in_channels, patch_spatial, patch_temporal, model_channels = ( self.concat_padding_mask, @@ -90,15 +83,8 @@ def build_hint_patch_embed(self): in_channels=in_channels, out_channels=model_channels, bias=False, - keep_spatio=True, - legacy_patch_emb=self.legacy_patch_emb, ) - if self.legacy_patch_emb: - # Initialize patch_embed like nn.Linear (instead of nn.Conv2d) - w = self.x_embedder2.proj.weight.data - nn.init.xavier_uniform_(w.view([w.shape[0], -1])) - def prepare_hint_embedded_sequence( self, x_B_C_T_H_W: torch.Tensor, fps: Optional[torch.Tensor] = None, padding_mask: Optional[torch.Tensor] = None ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: @@ -139,13 +125,7 @@ def encode_hint( ), f"Expected DataType, got {type(data_type)}. We need discuss this flag later." hint_B_T_H_W_D, _ = self.prepare_hint_embedded_sequence(hint, fps=fps, padding_mask=padding_mask) - hint = rearrange(hint_B_T_H_W_D, "B T H W D -> T H W B D") - if self.sequence_parallel: - tp_group = parallel_state.get_tensor_model_parallel_group() - T, H, W, B, D = hint.shape - hint = hint.view(T * H * W, 1, 1, B, -1) - hint = scatter_along_first_dim(hint, tp_group) guided_hint = self.input_hint_block(hint) return guided_hint @@ -157,7 +137,6 @@ def forward( crossattn_emb: torch.Tensor, crossattn_mask: Optional[torch.Tensor] = None, fps: Optional[torch.Tensor] = None, - image_size: Optional[torch.Tensor] = None, padding_mask: Optional[torch.Tensor] = None, scalar_feature: Optional[torch.Tensor] = None, data_type: Optional[DataType] = DataType.VIDEO, @@ -190,7 +169,6 @@ def forward( crossattn_emb=crossattn_emb_input, crossattn_mask=crossattn_mask_input, fps=fps, - image_size=image_size, padding_mask=padding_mask, scalar_feature=scalar_feature, data_type=data_type, @@ -223,7 +201,6 @@ def forward( ) input_list = [x, condition_video_input_mask] x = torch.cat(input_list, dim=1) - elif data_type == DataType.IMAGE: # For image, we dont have condition_video_input_mask, or condition_video_pose # We need to add the extra channel for video condition mask @@ -246,30 +223,20 @@ def forward( outs = {} - # (Experimental, not used in the released model) if also training base model, sometimes drop the - # controlnet branch to only train base branch. This is to prevent the network become dependent on - # controlnet branch and make control weight useless. - is_training = torch.is_grad_enabled() - is_training_base_model = any(p.requires_grad for p in base_model.parameters()) - if is_training and is_training_base_model: - coin_flip = torch.rand(B).to(x.device) > self.dropout_ctrl_branch # prob for only training base model - if self.blocks["block0"].x_format == "THWBD": - coin_flip = coin_flip[None, None, None, :, None] - elif self.blocks["block0"].x_format == "BTHWD": - coin_flip = coin_flip[:, None, None, None, None] - else: - coin_flip = 1 - num_control_blocks = self.layer_mask.index(True) num_layers_to_use = num_control_blocks control_gate_per_layer = [i < num_layers_to_use for i in range(num_control_blocks)] - + if isinstance(control_weight, torch.Tensor): if control_weight.ndim == 0: # Single scalar tensor - control_weight = [float(control_weight)] * len(guided_hints) + control_weight = [float(control_weight)] elif control_weight.ndim == 1: # List of scalar weights control_weight = [float(w) for w in control_weight] else: # Spatial-temporal weight maps + if self.cp_group is not None: + control_weight = split_inputs_cp( + control_weight, seq_dim=3, cp_group=self.cp_group + ) control_weight = [w for w in control_weight] # Keep as tensor else: control_weight = [control_weight] * len(guided_hints) @@ -303,19 +270,6 @@ def forward( if scalar_feature is not None: raise NotImplementedError("Scalar feature is not implemented yet.") - if self.additional_timestamp_channels: - additional_cond_B_D = self.prepare_additional_timestamp_embedder( - bs=x.shape[0], - fps=fps, - h=image_size[:, 0], - w=image_size[:, 1], - org_h=image_size[:, 2], - org_w=image_size[:, 3], - ) - - affline_emb_B_D += additional_cond_B_D - affline_scale_log_info["additional_cond_B_D"] = additional_cond_B_D.detach() - affline_scale_log_info["affline_emb_B_D"] = affline_emb_B_D.detach() affline_emb_B_D = affline_norm(affline_emb_B_D) @@ -329,23 +283,6 @@ def forward( extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D, "B T H W D -> T H W B D" ) - if self.sequence_parallel: - tp_group = parallel_state.get_tensor_model_parallel_group() - # Sequence parallel requires the input tensor to be scattered along the first dimension. - assert self.block_config == "FA-CA-MLP" # Only support this block config for now - T, H, W, B, D = x.shape - # variable name x_T_H_W_B_D is no longer valid. x is reshaped to THW*1*1*b*D and will be reshaped back in FinalLayer - x = x.view(T * H * W, 1, 1, B, D) - assert x.shape[0] % parallel_state.get_tensor_model_parallel_world_size() == 0 - x = scatter_along_first_dim(x, tp_group) - if extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D is not None: - extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D = extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D.view( - T * H * W, 1, 1, B, D - ) - extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D = scatter_along_first_dim( - extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D, tp_group - ) - for idx, (name, block) in enumerate(blocks.items()): assert ( blocks["block0"].x_format == block.x_format @@ -365,20 +302,14 @@ def forward( gate = control_gate_per_layer[idx] if isinstance(control_weight[i], (float, int)) or control_weight[i].ndim < 2: - hint_val = zero_blocks[name](x) * control_weight[i] * coin_flip * gate + hint_val = zero_blocks[name](x) * control_weight[i] * gate else: # Spatial-temporal weights [num_controls, B, 1, T, H, W] control_feat = zero_blocks[name](x) # Get current feature dimensions weight_map = control_weight[i] # [B, 1, T, H, W] # Reshape to match THWBD format weight_map = weight_map.permute(2, 3, 4, 0, 1) # [T, H, W, B, 1] - weight_map = weight_map.view(T * H * W, 1, 1, B, 1) - - if self.sequence_parallel: - weight_map = scatter_along_first_dim(weight_map, tp_group) - - hint_val = control_feat * weight_map * coin_flip * gate - + hint_val = control_feat * weight_map * gate if name not in outs: outs[name] = hint_val else: @@ -390,7 +321,6 @@ def forward( crossattn_emb=crossattn_emb_input, crossattn_mask=crossattn_mask_input, fps=fps, - image_size=image_size, padding_mask=padding_mask, scalar_feature=scalar_feature, data_type=data_type, @@ -398,4 +328,4 @@ def forward( condition_video_input_mask=condition_video_input_mask_input, **kwargs, ) - return output + return output \ No newline at end of file diff --git a/cosmos_transfer1/diffusion/networks/general_dit_video_conditioned.py b/cosmos_transfer1/diffusion/networks/general_dit_video_conditioned.py index c22aad4c..c1c13285 100644 --- a/cosmos_transfer1/diffusion/networks/general_dit_video_conditioned.py +++ b/cosmos_transfer1/diffusion/networks/general_dit_video_conditioned.py @@ -21,10 +21,9 @@ from torch import nn from cosmos_transfer1.diffusion.conditioner import DataType -from cosmos_transfer1.diffusion.module.blocks import SDXLTimesteps, SDXLTimestepEmbedding +from cosmos_transfer1.diffusion.module.blocks import TimestepEmbedding, Timesteps from cosmos_transfer1.diffusion.module.parallel import split_inputs_cp from cosmos_transfer1.diffusion.networks.general_dit import GeneralDIT -from cosmos_transfer1.diffusion.training.tensor_parallel import scatter_along_first_dim from cosmos_transfer1.utils import log @@ -34,18 +33,18 @@ def __init__(self, *args, in_channels=16 + 1, add_augment_sigma_embedding=False, # extra channel for video condition mask super().__init__(*args, in_channels=in_channels, **kwargs) - log.info(f"VideoExtendGeneralDIT in_channels: {in_channels}") + log.debug(f"VideoExtendGeneralDIT in_channels: {in_channels}") def build_additional_timestamp_embedder(self): super().build_additional_timestamp_embedder() if self.add_augment_sigma_embedding: log.info("Adding augment sigma embedding") self.augment_sigma_embedder = nn.Sequential( - SDXLTimesteps(self.model_channels), - SDXLTimestepEmbedding(self.model_channels, self.model_channels, use_adaln_lora=self.use_adaln_lora), + Timesteps(self.model_channels), + TimestepEmbedding(self.model_channels, self.model_channels, use_adaln_lora=self.use_adaln_lora), ) - def init_weights(self): + def initialize_weights(self): if self.add_augment_sigma_embedding: # Initialize timestep embedding for augment sigma nn.init.normal_(self.augment_sigma_embedder[1].linear_1.weight, std=0.02) @@ -55,7 +54,7 @@ def init_weights(self): if self.augment_sigma_embedder[1].linear_2.bias is not None: nn.init.constant_(self.augment_sigma_embedder[1].linear_2.bias, 0) - super().init_weights() # Call this last since it wil call TP weight init + super().initialize_weights() # Call this last since it wil call TP weight init def forward( self, @@ -64,7 +63,6 @@ def forward( crossattn_emb: torch.Tensor, crossattn_mask: Optional[torch.Tensor] = None, fps: Optional[torch.Tensor] = None, - image_size: Optional[torch.Tensor] = None, padding_mask: Optional[torch.Tensor] = None, scalar_feature: Optional[torch.Tensor] = None, data_type: Optional[DataType] = DataType.VIDEO, @@ -72,56 +70,49 @@ def forward( condition_video_indicator: Optional[torch.Tensor] = None, condition_video_input_mask: Optional[torch.Tensor] = None, condition_video_augment_sigma: Optional[torch.Tensor] = None, - condition_video_pose: Optional[torch.Tensor] = None, **kwargs, ) -> torch.Tensor: - """Args: - condition_video_augment_sigma: (B) tensor of sigma value for the conditional input augmentation - condition_video_pose: (B, 1, T, H, W) tensor of pose condition + """Forward pass of the video-conditioned DIT model. + + Args: + x: Input tensor of shape (B, C, T, H, W) + timesteps: Timestep tensor of shape (B,) + crossattn_emb: Cross attention embeddings of shape (B, N, D) + crossattn_mask: Optional cross attention mask of shape (B, N) + fps: Optional frames per second tensor + padding_mask: Optional padding mask tensor + scalar_feature: Optional scalar features tensor + data_type: Type of data being processed (default: DataType.VIDEO) + video_cond_bool: Optional video conditioning boolean tensor + condition_video_indicator: Optional video condition indicator tensor + condition_video_input_mask: Required mask tensor for video data type + condition_video_augment_sigma: Optional sigma values for conditional input augmentation + **kwargs: Additional keyword arguments + + Returns: + torch.Tensor: Output tensor """ B, C, T, H, W = x.shape if data_type == DataType.VIDEO: - assert ( - condition_video_input_mask is not None - ), "condition_video_input_mask is required for video data type; check if your model_obj is extend_model.FSDPDiffusionModel or the base DiffusionModel" - if self.cp_group is not None: - condition_video_input_mask = split_inputs_cp( - condition_video_input_mask, seq_dim=2, cp_group=self.cp_group - ) - condition_video_indicator = split_inputs_cp( - condition_video_indicator, seq_dim=2, cp_group=self.cp_group - ) - if condition_video_pose is not None: - condition_video_pose = split_inputs_cp(condition_video_pose, seq_dim=2, cp_group=self.cp_group) + assert condition_video_input_mask is not None, "condition_video_input_mask is required for video data type" + if parallel_state.is_initialized(): + cp_group = parallel_state.get_context_parallel_group() + condition_video_input_mask = split_inputs_cp(condition_video_input_mask, seq_dim=2, cp_group=cp_group) + condition_video_indicator = split_inputs_cp(condition_video_indicator, seq_dim=2, cp_group=cp_group) input_list = [x, condition_video_input_mask] - if condition_video_pose is not None: - if condition_video_pose.shape[2] > T: - log.warning( - f"condition_video_pose has more frames than the input video: {condition_video_pose.shape} > {x.shape}" - ) - condition_video_pose = condition_video_pose[:, :, :T, :, :].contiguous() - input_list.append(condition_video_pose) x = torch.cat( input_list, dim=1, ) - if data_type == DataType.IMAGE: - # For image, we dont have condition_video_input_mask, or condition_video_pose - # We need to add the extra channel for video condition mask - padding_channels = self.in_channels - x.shape[1] - x = torch.cat([x, torch.zeros((B, padding_channels, T, H, W), dtype=x.dtype, device=x.device)], dim=1) - else: - assert x.shape[1] == self.in_channels, f"Expected {self.in_channels} channels, got {x.shape[1]}" return super().forward( x=x, timesteps=timesteps, crossattn_emb=crossattn_emb, crossattn_mask=crossattn_mask, fps=fps, - image_size=image_size, padding_mask=padding_mask, scalar_feature=scalar_feature, data_type=data_type, @@ -136,7 +127,6 @@ def forward_before_blocks( crossattn_emb: torch.Tensor, crossattn_mask: Optional[torch.Tensor] = None, fps: Optional[torch.Tensor] = None, - image_size: Optional[torch.Tensor] = None, padding_mask: Optional[torch.Tensor] = None, scalar_feature: Optional[torch.Tensor] = None, data_type: Optional[DataType] = DataType.VIDEO, @@ -175,77 +165,33 @@ def forward_before_blocks( if scalar_feature is not None: raise NotImplementedError("Scalar feature is not implemented yet.") - timesteps_B_D = timesteps_B_D + scalar_feature.mean(dim=1) - if self.additional_timestamp_channels: - additional_cond_B_D = self.prepare_additional_timestamp_embedder( - bs=x.shape[0], - fps=fps, - h=image_size[:, 0], - w=image_size[:, 1], - org_h=image_size[:, 2], - org_w=image_size[:, 3], - ) - affline_emb_B_D += additional_cond_B_D - affline_scale_log_info["additional_cond_B_D"] = additional_cond_B_D.detach() if self.add_augment_sigma_embedding: if condition_video_augment_sigma is None: # Handling image case - # Note: for video case, when there is not condition frames, we also set it as zero, see - # the augment_conditional_latent_frames function in DiffusionV2WModel and ExtendDiffusionModel. + # Note: for video case, when there is not condition frames, we also set it as zero, see extend_model augment_conditional_latent_frames function assert data_type == DataType.IMAGE, "condition_video_augment_sigma is required for video data type" condition_video_augment_sigma = torch.zeros_like(timesteps.flatten()) - affline_augment_sigma_emb_B_D, adaln_lora_sigma_emb_B_3D = self.augment_sigma_embedder( - condition_video_augment_sigma.flatten() - ) + affline_augment_sigma_emb_B_D, _ = self.augment_sigma_embedder(condition_video_augment_sigma.flatten()) affline_emb_B_D = affline_emb_B_D + affline_augment_sigma_emb_B_D affline_scale_log_info["affline_emb_B_D"] = affline_emb_B_D.detach() affline_emb_B_D = self.affline_norm(affline_emb_B_D) - # for logging purpose - self.affline_scale_log_info = affline_scale_log_info - self.affline_emb = affline_emb_B_D - self.crossattn_emb = crossattn_emb - self.crossattn_mask = crossattn_mask - if self.use_cross_attn_mask: crossattn_mask = crossattn_mask[:, None, None, :].to(dtype=torch.bool) # [B, 1, 1, length] else: crossattn_mask = None - if self.blocks["block0"].x_format == "THWBD": - x = rearrange(x_B_T_H_W_D, "B T H W D -> T H W B D") - if extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D is not None: - extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D = rearrange( - extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D, "B T H W D -> T H W B D" - ) - crossattn_emb = rearrange(crossattn_emb, "B M D -> M B D") - if crossattn_mask: - crossattn_mask = rearrange(crossattn_mask, "B M -> M B") - - if self.sequence_parallel: - tp_group = parallel_state.get_tensor_model_parallel_group() - # Sequence parallel requires the input tensor to be scattered along the first dimension. - assert self.block_config == "FA-CA-MLP" # Only support this block config for now - T, H, W, B, D = x.shape - # variable name x_T_H_W_B_D is no longer valid. x is reshaped to THW*1*1*b*D and will be reshaped back in FinalLayer - x = x.view(T * H * W, 1, 1, B, D) - assert x.shape[0] % parallel_state.get_tensor_model_parallel_world_size() == 0 - x = scatter_along_first_dim(x, tp_group) - - if extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D is not None: - extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D = extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D.view( - T * H * W, 1, 1, B, D - ) - extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D = scatter_along_first_dim( - extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D, tp_group - ) + x = rearrange(x_B_T_H_W_D, "B T H W D -> T H W B D") + if extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D is not None: + extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D = rearrange( + extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D, "B T H W D -> T H W B D" + ) + crossattn_emb = rearrange(crossattn_emb, "B M D -> M B D") + if crossattn_mask: + crossattn_mask = rearrange(crossattn_mask, "B M -> M B") - elif self.blocks["block0"].x_format == "BTHWD": - x = x_B_T_H_W_D - else: - raise ValueError(f"Unknown x_format {self.blocks[0].x_format}") output = { "x": x, "affline_emb_B_D": affline_emb_B_D, @@ -256,4 +202,4 @@ def forward_before_blocks( "original_shape": original_shape, "extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D": extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D, } - return output + return output \ No newline at end of file diff --git a/cosmos_transfer1/diffusion/training/models/model_image.py b/cosmos_transfer1/diffusion/training/models/model_image.py index ce4edc00..203024e6 100644 --- a/cosmos_transfer1/diffusion/training/models/model_image.py +++ b/cosmos_transfer1/diffusion/training/models/model_image.py @@ -83,7 +83,7 @@ def __init__(self, config): # 3. vae with misc.timer("DiffusionModel: set_up_vae"): - self.vae: BaseVAE = lazy_instantiate(config.vae) + self.vae: BaseVAE = lazy_instantiate(config.tokenizer) assert ( self.vae.latent_ch == self.state_shape[0] ), f"latent_ch {self.vae.latent_ch} != state_shape {self.state_shape[0]}" diff --git a/cosmos_transfer1/diffusion/training/modules/blocks.py b/cosmos_transfer1/diffusion/training/modules/blocks.py new file mode 100644 index 00000000..cc05a54f --- /dev/null +++ b/cosmos_transfer1/diffusion/training/modules/blocks.py @@ -0,0 +1,1118 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from typing import Optional + +import torch +from einops import rearrange, repeat +from einops.layers.torch import Rearrange +from megatron.core import parallel_state +from torch import nn +from transformer_engine.pytorch.attention import apply_rotary_pos_emb + +from cosmos_transfer1.diffusion.module.attention import Attention, GPT2FeedForward +from cosmos_transfer1.diffusion.training.tensor_parallel import gather_along_first_dim +from cosmos_transfer1.utils import log + + +class SDXLTimesteps(nn.Module): + def __init__(self, num_channels: int = 320): + super().__init__() + self.num_channels = num_channels + + def forward(self, timesteps): + in_dype = timesteps.dtype + half_dim = self.num_channels // 2 + exponent = -math.log(10000) * torch.arange(half_dim, dtype=torch.float32, device=timesteps.device) + exponent = exponent / (half_dim - 0.0) + + emb = torch.exp(exponent) + emb = timesteps[:, None].float() * emb[None, :] + + sin_emb = torch.sin(emb) + cos_emb = torch.cos(emb) + emb = torch.cat([cos_emb, sin_emb], dim=-1) + + return emb.to(in_dype) + + +class SDXLTimestepEmbedding(nn.Module): + def __init__(self, in_features: int, out_features: int, use_adaln_lora: bool = False): + super().__init__() + log.critical( + f"Using AdaLN LoRA Flag: {use_adaln_lora}. We enable bias if no AdaLN LoRA for backward compatibility." + ) + self.linear_1 = nn.Linear(in_features, out_features, bias=not use_adaln_lora) + self.activation = nn.SiLU() + self.use_adaln_lora = use_adaln_lora + if use_adaln_lora: + self.linear_2 = nn.Linear(out_features, 3 * out_features, bias=False) + else: + self.linear_2 = nn.Linear(out_features, out_features, bias=True) + + def forward(self, sample: torch.Tensor) -> torch.Tensor: + emb = self.linear_1(sample) + emb = self.activation(emb) + emb = self.linear_2(emb) + + if self.use_adaln_lora: + adaln_lora_B_3D = emb + emb_B_D = sample + else: + emb_B_D = emb + adaln_lora_B_3D = None + + return emb_B_D, adaln_lora_B_3D + + +def modulate(x, shift, scale): + return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1) + + +class PatchEmbed(nn.Module): + """ + PatchEmbed is a module for embedding patches from an input tensor by applying either 3D or 2D convolutional layers, + depending on the . This module can process inputs with temporal (video) and spatial (image) dimensions, + making it suitable for video and image processing tasks. It supports dividing the input into patches and embedding each + patch into a vector of size `out_channels`. + + Parameters: + - spatial_patch_size (int): The size of each spatial patch. + - temporal_patch_size (int): The size of each temporal patch. + - in_channels (int): Number of input channels. Default: 3. + - out_channels (int): The dimension of the embedding vector for each patch. Default: 768. + - bias (bool): If True, adds a learnable bias to the output of the convolutional layers. Default: True. + - keep_spatio (bool): If True, the spatial dimensions are kept separate in the output tensor, otherwise, they are flattened. Default: False. + - legacy_patch_emb (bool): If True, applies 3D convolutional layers for video inputs, otherwise, use Linear! The legacy model is for backward compatibility. Default: True. + The output shape of the module depends on the `keep_spatio` flag. If `keep_spatio`=True, the output retains the spatial dimensions. + Otherwise, the spatial dimensions are flattened into a single dimension. + """ + + def __init__( + self, + spatial_patch_size, + temporal_patch_size, + in_channels=3, + out_channels=768, + bias=True, + keep_spatio=False, + legacy_patch_emb: bool = True, + ): + super().__init__() + self.spatial_patch_size = spatial_patch_size + self.temporal_patch_size = temporal_patch_size + assert keep_spatio, "Only support keep_spatio=True" + self.keep_spatio = keep_spatio + self.legacy_patch_emb = legacy_patch_emb + + if legacy_patch_emb: + self.proj = nn.Conv3d( + in_channels, + out_channels, + kernel_size=(temporal_patch_size, spatial_patch_size, spatial_patch_size), + stride=(temporal_patch_size, spatial_patch_size, spatial_patch_size), + bias=bias, + ) + self.out = Rearrange("b c t h w -> b t h w c") + else: + self.proj = nn.Sequential( + Rearrange( + "b c (t r) (h m) (w n) -> b t h w (c r m n)", + r=temporal_patch_size, + m=spatial_patch_size, + n=spatial_patch_size, + ), + nn.Linear( + in_channels * spatial_patch_size * spatial_patch_size * temporal_patch_size, out_channels, bias=bias + ), + ) + self.out = nn.Identity() + + def forward(self, x): + """ + Forward pass of the PatchEmbed module. + + Parameters: + - x (torch.Tensor): The input tensor of shape (B, C, T, H, W) where + B is the batch size, + C is the number of channels, + T is the temporal dimension, + H is the height, and + W is the width of the input. + + Returns: + - torch.Tensor: The embedded patches as a tensor, with shape b t h w c. + """ + assert x.dim() == 5 + _, _, T, H, W = x.shape + assert H % self.spatial_patch_size == 0 and W % self.spatial_patch_size == 0 + assert T % self.temporal_patch_size == 0 + x = self.proj(x) + return self.out(x) + + +class ExtraTokenPatchEmbed(PatchEmbed): + def __init__(self, *args, out_channels: int = 768, keep_spatio: bool = False, **kwargs): + assert keep_spatio, "ExtraTokenPatchEmbed only supports keep_spatio=True" + super().__init__(*args, out_channels=out_channels, keep_spatio=keep_spatio, **kwargs) + self.temporal_token = nn.Parameter(torch.randn(1, 1, 1, 1, out_channels)) + self.spatial_token = nn.Parameter(torch.randn(1, 1, 1, 1, out_channels)) + + def forward(self, x): + x_B_T_H_W_C = super().forward(x) + B, T, H, W, C = x_B_T_H_W_C.shape + x_B_T_H_W_C = torch.cat( + [ + x_B_T_H_W_C, + self.temporal_token.repeat(B, 1, H, W, 1), + ], + dim=1, + ) + x_B_T_H_W_C = torch.cat( + [ + x_B_T_H_W_C, + self.spatial_token.repeat(B, T, H, 1, 1), + ], + dim=3, + ) + return x_B_T_H_W_C + + +class ExpertChoiceMoEGate(nn.Module): + """ + ExpertChoiceMoEGate determines which tokens go + to which experts (and how much to weigh each expert). + + Args: + hidden_size (int): Dimensionality of input features. + num_experts (int): Number of experts (E). + capacity (int): Capacity (number of tokens) each expert can process (C). + """ + + def __init__( + self, + hidden_size: int, + num_experts: int, + capacity: int, + ): + super().__init__() + self.hidden_size = hidden_size + self.num_experts = num_experts + self.capacity = capacity + + self.router = nn.Parameter(torch.empty((self.num_experts, self.hidden_size))) + self.reset_parameters() + + def reset_parameters(self): + torch.nn.init.kaiming_uniform_(self.router) + + def forward(self, x: torch.Tensor): + """ + Args: + x (Tensor): Input of shape (B, S, D) + Returns: + gating (Tensor): Gating weights of shape (B, E, C), + where E = num_experts, C = capacity (top-k). + dispatch (Tensor): Dispatch mask of shape (B, E, C, S). + index (Tensor): Indices of top-k tokens for each expert, + shape (B, E, C). + """ + B, S, D = x.shape + E, C = self.num_experts, self.capacity + + # token-expert affinity scores + logits = torch.einsum("bsd,de->bse", x, self.router) + affinity = torch.nn.functional.softmax(logits, dim=-1) # (B, S, E) + + # gather topk tokens for each expert + affinity_t = affinity.transpose(1, 2) # (B, E, S) + + # select top-k tokens for each expert + gating, index = torch.topk(affinity_t, k=C, dim=-1) # (B, E, C) + + # one-hot dispatch mask + dispatch = torch.nn.functional.one_hot(index, num_classes=S).float() # (B, E, C, S) + + return gating, dispatch, index + + +class ExpertChoiceMoELayer(nn.Module): + """ + ExpertChoiceMoELayer uses the ExpertChoiceMoEGate to route tokens + to experts, process them, and then combine the outputs. + + Args: + gate_hidden_size (int): Dimensionality of input features. + ffn_hidden_size (int): Dimension of hidden layer in each expert feedforward (e.g., GPT2FeedForward). + num_experts (int): Number of experts (E). + capacity (int): Capacity (number of tokens) each expert can process (C). + expert_cls (nn.Module): The class to instantiate each expert. Defaults to GPT2FeedForward. + expert_kwargs (dict): Extra kwargs to pass to each expert class. + """ + + def __init__( + self, + gate_hidden_size: int, + ffn_hidden_size: int, + num_experts: int, + capacity: int, + expert_class: nn.Module = GPT2FeedForward, + expert_kwargs=None, + ): + super().__init__() + if not expert_kwargs: + expert_kwargs = {} + + self.gate_hidden_size = gate_hidden_size + self.ffn_hidden_size = ffn_hidden_size + self.num_experts = num_experts + self.capacity = capacity + + self.gate = ExpertChoiceMoEGate(gate_hidden_size, num_experts, capacity) + + self.experts = nn.ModuleList( + [expert_class(gate_hidden_size, ffn_hidden_size, **expert_kwargs) for _ in range(num_experts)] + ) + + def forward(self, x: torch.Tensor): + """ + Args: + x (Tensor): Input of shape (B, S, D). + + Returns: + x_out (Tensor): Output of shape (B, S, D), after dispatching tokens + to experts and combining their outputs. + """ + B, S, D = x.shape + E, C = self.num_experts, self.capacity + + # gating: (B, E, C) + # dispatch: (B, E, C, S) + gating, dispatch, index = self.gate(x) + + # collect input tokens for each expert + x_in = torch.einsum("becs,bsd->becd", dispatch, x) + + # process through each expert + expert_outputs = [self.experts[e](x_in[:, e]) for e in range(E)] + + x_e = torch.stack(expert_outputs, dim=1) # (B, E, C, D) + + # gating: (B, E, C), dispatch: (B, E, C, S), x_e: (B, E, C, d) + # x_out: (B, S, D) + # each token is placed back to its location with weighting + x_out = torch.einsum("becs,bec,becd->bsd", dispatch, gating, x_e) + + return x_out + + +class FinalLayer(nn.Module): + """ + The final layer of video DiT. + """ + + def __init__( + self, + hidden_size, + spatial_patch_size, + temporal_patch_size, + out_channels, + use_adaln_lora: bool = False, + adaln_lora_dim: int = 256, + ): + super().__init__() + self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.linear = nn.Linear( + hidden_size, spatial_patch_size * spatial_patch_size * temporal_patch_size * out_channels, bias=False + ) + self.hidden_size = hidden_size + self.n_adaln_chunks = 2 + self.use_adaln_lora = use_adaln_lora + if use_adaln_lora: + self.adaLN_modulation = nn.Sequential( + nn.SiLU(), + nn.Linear(hidden_size, adaln_lora_dim, bias=False), + nn.Linear(adaln_lora_dim, self.n_adaln_chunks * hidden_size, bias=False), + ) + else: + self.adaLN_modulation = nn.Sequential( + nn.SiLU(), nn.Linear(hidden_size, self.n_adaln_chunks * hidden_size, bias=False) + ) + + self.sequence_parallel = getattr(parallel_state, "sequence_parallel", False) + + def forward( + self, + x_BT_HW_D, + emb_B_D, + adaln_lora_B_3D: Optional[torch.Tensor] = None, + ): + if self.use_adaln_lora: + assert adaln_lora_B_3D is not None + shift_B_D, scale_B_D = (self.adaLN_modulation(emb_B_D) + adaln_lora_B_3D[:, : 2 * self.hidden_size]).chunk( + 2, dim=1 + ) + else: + shift_B_D, scale_B_D = self.adaLN_modulation(emb_B_D).chunk(2, dim=1) + + B = emb_B_D.shape[0] + T = x_BT_HW_D.shape[0] // B + shift_BT_D, scale_BT_D = repeat(shift_B_D, "b d -> (b t) d", t=T), repeat(scale_B_D, "b d -> (b t) d", t=T) + x_BT_HW_D = modulate(self.norm_final(x_BT_HW_D), shift_BT_D, scale_BT_D) + if self.sequence_parallel: + x_T_B_HW_D = rearrange(x_BT_HW_D, "(b t) hw d -> t b hw d", b=B, t=T) + x_T_B_HW_D = gather_along_first_dim(x_T_B_HW_D, parallel_state.get_tensor_model_parallel_group()) + x_BT_HW_D = rearrange(x_T_B_HW_D, "t b hw d -> (b t) hw d", b=B) + + x_BT_HW_D = self.linear(x_BT_HW_D) + return x_BT_HW_D + + def forward_with_memory_save( + self, + x_BT_HW_D_before_gate: torch.Tensor, + x_BT_HW_D_skip: torch.Tensor, + gate_L_B_D: torch.Tensor, + emb_B_D, + adaln_lora_B_3D: Optional[torch.Tensor] = None, + ): + if self.use_adaln_lora: + assert adaln_lora_B_3D is not None + shift_B_D, scale_B_D = (self.adaLN_modulation(emb_B_D) + adaln_lora_B_3D[:, : 2 * self.hidden_size]).chunk( + 2, dim=1 + ) + else: + shift_B_D, scale_B_D = self.adaLN_modulation(emb_B_D).chunk(2, dim=1) + + B = emb_B_D.shape[0] + T = x_BT_HW_D_before_gate.shape[0] // B + shift_BT_D, scale_BT_D = repeat(shift_B_D, "b d -> (b t) d", t=T), repeat(scale_B_D, "b d -> (b t) d", t=T) + gate_BT_1_D = repeat(gate_L_B_D, "1 b d -> (b t) 1 d", t=T) + + def _fn(_x_before_gate, _x_skip): + previous_block_out = _x_skip + gate_BT_1_D * _x_before_gate + _x = modulate(self.norm_final(previous_block_out), shift_BT_D, scale_BT_D) + return self.linear(_x) + + return torch.utils.checkpoint.checkpoint(_fn, x_BT_HW_D_before_gate, x_BT_HW_D_skip, use_reentrant=False) + + +class VideoAttn(nn.Module): + """ + Implements video attention with optional cross-attention capabilities. + + This module supports both self-attention within the video frames and cross-attention + with an external context. It's designed to work with flattened spatial dimensions + to accommodate for video input. + + Attributes: + x_dim (int): Dimensionality of the input feature vectors. + context_dim (Optional[int]): Dimensionality of the external context features. + If None, the attention does not utilize external context. + num_heads (int): Number of attention heads. + bias (bool): If true, bias is added to the query, key, value projections. + x_format (str): The shape format of x tenosor. + n_views (int): Extra parameter used in multi-view diffusion model. It indicated total number of view we model together. + """ + + def __init__( + self, + x_dim: int, + context_dim: Optional[int], + num_heads: int, + bias: bool = False, + x_format: str = "BTHWD", + n_views: int = 1, + ) -> None: + super().__init__() + self.n_views = n_views + self.x_format = x_format + if self.x_format == "BTHWD": + qkv_format = "bshd" + elif self.x_format == "THWBD": + qkv_format = "sbhd" + else: + raise NotImplementedError(f"Unsupported x_format: {self.x_format}") + + self.attn = Attention( + x_dim, + context_dim, + num_heads, + x_dim // num_heads, + qkv_bias=bias, + qkv_norm="RRI", + out_bias=bias, + qkv_format=qkv_format, + ) + + def forward( + self, + x: torch.Tensor, + context: Optional[torch.Tensor] = None, + crossattn_mask: Optional[torch.Tensor] = None, + rope_emb_L_1_1_D: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """ + Forward pass for video attention. + + Args: + x (Tensor): Input tensor of shape (B, T, H, W, D) or (T, H, W, B, D) representing batches of video data. + context (Tensor): Context tensor of shape (B, M, D) or (M, B, D), where M is the sequence length of the context. + crossattn_mask (Optional[Tensor]): An optional mask for cross-attention mechanisms. + rope_emb_L_1_1_D (Optional[Tensor]): Rotary positional embedding tensor of shape (L, 1, 1, D). L == THW for current video training. transformer_engine format + + Returns: + Tensor: The output tensor with applied attention, maintaining the input shape. + """ + + if self.x_format == "BTHWD": + if context is not None and self.n_views > 1: + x_B_T_H_W_D = rearrange(x, "b (v t) h w d -> (v b) t h w d", v=self.n_views) + context_B_M_D = rearrange(context, "b (v m) d -> (v b) m d", v=self.n_views) + else: + x_B_T_H_W_D = x + context_B_M_D = context + B, T, H, W, D = x_B_T_H_W_D.shape + x_B_THW_D = rearrange(x_B_T_H_W_D, "b t h w d -> b (t h w) d") + x_B_THW_D = self.attn(x_B_THW_D, context_B_M_D, crossattn_mask, rope_emb=rope_emb_L_1_1_D) + + # reshape it back to video format + x_B_T_H_W_D = rearrange(x_B_THW_D, "b (t h w) d -> b t h w d", h=H, w=W) + if context is not None and self.n_views > 1: + x_B_T_H_W_D = rearrange(x_B_T_H_W_D, "(v b) t h w d -> b (v t) h w d", v=self.n_views) + return x_B_T_H_W_D + elif self.x_format == "THWBD": + if context is not None and self.n_views > 1: + x_T_H_W_B_D = rearrange(x, "(v t) h w b d -> t h w (v b) d", v=self.n_views) + context_M_B_D = rearrange(context, "(v m) b d -> m (v b) d", v=self.n_views) + else: + x_T_H_W_B_D = x + context_M_B_D = context + T, H, W, B, D = x_T_H_W_B_D.shape + x_THW_B_D = rearrange(x_T_H_W_B_D, "t h w b d -> (t h w) b d") + x_THW_B_D = self.attn( + x_THW_B_D, + context_M_B_D, + crossattn_mask, + rope_emb=rope_emb_L_1_1_D, + ) + x_T_H_W_B_D = rearrange(x_THW_B_D, "(t h w) b d -> t h w b d", h=H, w=W) + if context is not None and self.n_views > 1: + x_T_H_W_B_D = rearrange(x_T_H_W_B_D, "t h w (v b) d -> (v t) h w b d", v=self.n_views) + return x_T_H_W_B_D + else: + raise NotImplementedError(f"Unsupported x_format: {self.x_format}") + + +def checkpoint_norm_state(norm_state, x, scale, shift): + normalized = norm_state(x) + return normalized * (1 + scale) + shift + + +class DITBuildingBlock(nn.Module): + """ + DIT Building Block for constructing various types of attention or MLP blocks dynamically based on a specified block type. + + This class instantiates different types of buildig block / attn and MLP based on config, and applies crossponding forward pass during training. + + Attributes: + block_type (str): Type of block to be used ('spatial_sa', 'temporal_sa', 'cross_attn', 'full_attn', 'mlp'). + x_dim (int): Dimensionality of the input features. + context_dim (Optional[int]): Dimensionality of the external context, required for cross attention blocks. + num_heads (int): Number of attention heads. + mlp_ratio (float): Multiplier for the dimensionality of the MLP hidden layer compared to input. + spatial_win_size (int): Window size for spatial self-attention. + temporal_win_size (int): Window size for temporal self-attention. + bias (bool): Whether to include bias in attention and MLP computations. + mlp_dropout (float): Dropout rate for MLP blocks. + n_views (int): Extra parameter used in multi-view diffusion model. It indicated total number of view we model together. + """ + + def __init__( + self, + block_type: str, + x_dim: int, + context_dim: Optional[int], + num_heads: int, + mlp_ratio: float = 4.0, + window_sizes: list = [], + spatial_win_size: int = 1, + temporal_win_size: int = 1, + bias: bool = False, + mlp_dropout: float = 0.0, + x_format: str = "BTHWD", + use_adaln_lora: bool = False, + adaln_lora_dim: int = 256, + n_views: int = 1, + ) -> None: + block_type = block_type.lower() + + super().__init__() + self.x_format = x_format + if block_type in ["cross_attn", "ca"]: + self.block = VideoAttn( + x_dim, + context_dim, + num_heads, + bias=bias, + x_format=self.x_format, + n_views=n_views, + ) + elif block_type in ["full_attn", "fa"]: + self.block = VideoAttn(x_dim, None, num_heads, bias=bias, x_format=self.x_format) + elif block_type in ["mlp", "ff"]: + self.block = GPT2FeedForward(x_dim, int(x_dim * mlp_ratio), dropout=mlp_dropout, bias=bias) + else: + raise ValueError(f"Unknown block type: {block_type}") + + self.block_type = block_type + self.use_adaln_lora = use_adaln_lora + + self.norm_state = nn.LayerNorm(x_dim, elementwise_affine=False, eps=1e-6) + self.n_adaln_chunks = 3 + if use_adaln_lora: + self.adaLN_modulation = nn.Sequential( + nn.SiLU(), + nn.Linear(x_dim, adaln_lora_dim, bias=False), + nn.Linear(adaln_lora_dim, self.n_adaln_chunks * x_dim, bias=False), + ) + else: + self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(x_dim, self.n_adaln_chunks * x_dim, bias=False)) + + def forward_with_attn_memory_save( + self, + x_before_gate: torch.Tensor, + x_skip: torch.Tensor, + gate_L_B_D: torch.Tensor, + emb_B_D: torch.Tensor, + crossattn_emb: torch.Tensor, + crossattn_mask: Optional[torch.Tensor] = None, + rope_emb_L_1_1_D: Optional[torch.Tensor] = None, + adaln_lora_B_3D: Optional[torch.Tensor] = None, + extra_per_block_pos_emb: Optional[torch.Tensor] = None, + ): + del crossattn_mask + assert isinstance(self.block, VideoAttn), "only support VideoAttn impl" + if self.use_adaln_lora: + shift_B_D, scale_B_D, gate_B_D = (self.adaLN_modulation(emb_B_D) + adaln_lora_B_3D).chunk( + self.n_adaln_chunks, dim=1 + ) + else: + shift_B_D, scale_B_D, gate_B_D = self.adaLN_modulation(emb_B_D).chunk(self.n_adaln_chunks, dim=1) + + shift_L_B_D, scale_L_B_D, _gate_L_B_D = ( + shift_B_D.unsqueeze(0), + scale_B_D.unsqueeze(0), + gate_B_D.unsqueeze(0), + ) + + def _fn(_x_before_gate, _x_skip, _context): + previous_block_out = _x_skip + gate_L_B_D * _x_before_gate + if extra_per_block_pos_emb is not None: + previous_block_out = previous_block_out + extra_per_block_pos_emb + _normalized_x = self.norm_state(previous_block_out) + normalized_x = _normalized_x * (1 + scale_L_B_D) + shift_L_B_D + # context = normalized_x if _context is None else _context + context = normalized_x if self.block.attn.is_selfattn else _context + return ( + self.block.attn.to_q[0](normalized_x), + self.block.attn.to_k[0](context), + self.block.attn.to_v[0](context), + previous_block_out, + ) + + q, k, v, previous_block_out = torch.utils.checkpoint.checkpoint( + _fn, x_before_gate, x_skip, crossattn_emb, use_reentrant=False + ) + + def attn_fn(_q, _k, _v): + q, k, v = map( + lambda t: rearrange( + t, + "b ... (n c) -> b ... n c", + n=self.block.attn.heads // self.block.attn.tp_size, + c=self.block.attn.dim_head, + ), + (_q, _k, _v), + ) + q = self.block.attn.to_q[1](q) + k = self.block.attn.to_k[1](k) + v = self.block.attn.to_v[1](v) + if self.block.attn.is_selfattn and rope_emb_L_1_1_D is not None: # only apply to self-attention! + q = apply_rotary_pos_emb(q, rope_emb_L_1_1_D, tensor_format=self.block.attn.qkv_format, fused=True) + k = apply_rotary_pos_emb(k, rope_emb_L_1_1_D, tensor_format=self.block.attn.qkv_format, fused=True) + + if self.block.attn.is_selfattn: + return q, k, v + + seq_dim = self.block.attn.qkv_format.index("s") + assert ( + q.shape[seq_dim] > 1 and k.shape[seq_dim] > 1 + ), "Seqlen must be larger than 1 for TE Attention starting with 1.8 TE version." + return self.block.attn.attn_op( + q, k, v, core_attention_bias_type="no_bias", core_attention_bias=None + ) # [B, Mq, H, V] + + assert self.block.attn.backend == "transformer_engine", "Only support transformer_engine backend for now." + + if self.block.attn.is_selfattn: + q, k, v = torch.utils.checkpoint.checkpoint(attn_fn, q, k, v, use_reentrant=False) + seq_dim = self.block.attn.qkv_format.index("s") + assert ( + q.shape[seq_dim] > 1 and k.shape[seq_dim] > 1 + ), "Seqlen must be larger than 1 for TE Attention starting with 1.8 TE version." + softmax_attn_output = self.block.attn.attn_op( + q, k, v, core_attention_bias_type="no_bias", core_attention_bias=None + ) # [B, Mq, H, V] + else: + softmax_attn_output = torch.utils.checkpoint.checkpoint(attn_fn, q, k, v, use_reentrant=False) + attn_out = self.block.attn.to_out(softmax_attn_output) + return _gate_L_B_D, attn_out, previous_block_out + + def forward_with_x_attn_memory_save( + self, + x_before_gate: torch.Tensor, + x_skip: torch.Tensor, + gate_L_B_D: torch.Tensor, + emb_B_D: torch.Tensor, + crossattn_emb: torch.Tensor, + crossattn_mask: Optional[torch.Tensor] = None, + rope_emb_L_1_1_D: Optional[torch.Tensor] = None, + adaln_lora_B_3D: Optional[torch.Tensor] = None, + extra_per_block_pos_emb: Optional[torch.Tensor] = None, + ): + del crossattn_mask + assert isinstance(self.block, VideoAttn) + if self.use_adaln_lora: + shift_B_D, scale_B_D, gate_B_D = (self.adaLN_modulation(emb_B_D) + adaln_lora_B_3D).chunk( + self.n_adaln_chunks, dim=1 + ) + else: + shift_B_D, scale_B_D, gate_B_D = self.adaLN_modulation(emb_B_D).chunk(self.n_adaln_chunks, dim=1) + + shift_L_B_D, scale_L_B_D, _gate_L_B_D = ( + shift_B_D.unsqueeze(0), + scale_B_D.unsqueeze(0), + gate_B_D.unsqueeze(0), + ) + + def _fn(_x_before_gate, _x_skip, _context): + previous_block_out = _x_skip + gate_L_B_D * _x_before_gate + if extra_per_block_pos_emb is not None: + previous_block_out = previous_block_out + extra_per_block_pos_emb + _normalized_x = self.norm_state(previous_block_out) + normalized_x = _normalized_x * (1 + scale_L_B_D) + shift_L_B_D + # context = normalized_x if _context is None else _context + context = normalized_x if self.block.attn.is_selfattn else _context + return ( + self.block.attn.to_q[0](normalized_x), + self.block.attn.to_k[0](context), + self.block.attn.to_v[0](context), + previous_block_out, + ) + + q, k, v, previous_block_out = torch.utils.checkpoint.checkpoint( + _fn, x_before_gate, x_skip, crossattn_emb, use_reentrant=False + ) + + def x_attn_fn(_q, _k, _v): + q, k, v = map( + lambda t: rearrange( + t, + "b ... (n c) -> b ... n c", + n=self.block.attn.heads // self.block.attn.tp_size, + c=self.block.attn.dim_head, + ), + (_q, _k, _v), + ) + q = self.block.attn.to_q[1](q) + k = self.block.attn.to_k[1](k) + v = self.block.attn.to_v[1](v) + if self.block.attn.is_selfattn and rope_emb_L_1_1_D is not None: # only apply to self-attention! + q = apply_rotary_pos_emb(q, rope_emb_L_1_1_D, tensor_format=self.block.attn.qkv_format, fused=True) + k = apply_rotary_pos_emb(k, rope_emb_L_1_1_D, tensor_format=self.block.attn.qkv_format, fused=True) + + seq_dim = self.block.attn.qkv_format.index("s") + assert ( + q.shape[seq_dim] > 1 and k.shape[seq_dim] > 1 + ), "Seqlen must be larger than 1 for TE Attention starting with 1.8 TE version." + softmax_attn_output = self.block.attn.attn_op( + q, k, v, core_attention_bias_type="no_bias", core_attention_bias=None + ) # [B, Mq, H, V] + return self.block.attn.to_out(softmax_attn_output) + + assert self.block.attn.backend == "transformer_engine", "Only support transformer_engine backend for now." + + attn_out = torch.utils.checkpoint.checkpoint(x_attn_fn, q, k, v, use_reentrant=False) + return _gate_L_B_D, attn_out, previous_block_out + + def forward_with_ffn_memory_save( + self, + x_before_gate: torch.Tensor, + x_skip: torch.Tensor, + gate_L_B_D: torch.Tensor, + emb_B_D: torch.Tensor, + crossattn_emb: torch.Tensor, + crossattn_mask: Optional[torch.Tensor] = None, + rope_emb_L_1_1_D: Optional[torch.Tensor] = None, + adaln_lora_B_3D: Optional[torch.Tensor] = None, + extra_per_block_pos_emb: Optional[torch.Tensor] = None, + ): + del crossattn_emb, crossattn_mask, rope_emb_L_1_1_D + assert isinstance(self.block, GPT2FeedForward) + if self.use_adaln_lora: + shift_B_D, scale_B_D, gate_B_D = (self.adaLN_modulation(emb_B_D) + adaln_lora_B_3D).chunk( + self.n_adaln_chunks, dim=1 + ) + else: + shift_B_D, scale_B_D, gate_B_D = self.adaLN_modulation(emb_B_D).chunk(self.n_adaln_chunks, dim=1) + + shift_L_B_D, scale_L_B_D, _gate_L_B_D = ( + shift_B_D.unsqueeze(0), + scale_B_D.unsqueeze(0), + gate_B_D.unsqueeze(0), + ) + + def _fn(_x_before_gate, _x_skip): + previous_block_out = _x_skip + gate_L_B_D * _x_before_gate + if extra_per_block_pos_emb is not None: + previous_block_out = previous_block_out + extra_per_block_pos_emb + _normalized_x = self.norm_state(previous_block_out) + normalized_x = _normalized_x * (1 + scale_L_B_D) + shift_L_B_D + + assert self.block.dropout.p == 0.0, "we skip dropout to save memory" + + return self.block.layer1(normalized_x), previous_block_out + + intermediate_output, previous_block_out = torch.utils.checkpoint.checkpoint( + _fn, x_before_gate, x_skip, use_reentrant=False + ) + + def _fn2(_x): + _x = self.block.activation(_x) + return self.block.layer2(_x) + + return ( + _gate_L_B_D, + torch.utils.checkpoint.checkpoint(_fn2, intermediate_output, use_reentrant=False), + previous_block_out, + ) + + def forward_with_ffn_memory_save_upgrade( + self, + x_before_gate: torch.Tensor, + x_skip: torch.Tensor, + gate_L_B_D: torch.Tensor, + emb_B_D: torch.Tensor, + crossattn_emb: torch.Tensor, + crossattn_mask: Optional[torch.Tensor] = None, + rope_emb_L_1_1_D: Optional[torch.Tensor] = None, + adaln_lora_B_3D: Optional[torch.Tensor] = None, + extra_per_block_pos_emb: Optional[torch.Tensor] = None, + ): + del crossattn_emb, crossattn_mask, rope_emb_L_1_1_D + assert isinstance(self.block, GPT2FeedForward) + if self.use_adaln_lora: + shift_B_D, scale_B_D, gate_B_D = (self.adaLN_modulation(emb_B_D) + adaln_lora_B_3D).chunk( + self.n_adaln_chunks, dim=1 + ) + else: + shift_B_D, scale_B_D, gate_B_D = self.adaLN_modulation(emb_B_D).chunk(self.n_adaln_chunks, dim=1) + + shift_L_B_D, scale_L_B_D, _gate_L_B_D = ( + shift_B_D.unsqueeze(0), + scale_B_D.unsqueeze(0), + gate_B_D.unsqueeze(0), + ) + + def _fn2(_x): + _x = self.block.activation(_x) + return self.block.layer2(_x) + + def _fn(_x_before_gate, _x_skip): + previous_block_out = _x_skip + gate_L_B_D * _x_before_gate + if extra_per_block_pos_emb is not None: + previous_block_out = previous_block_out + extra_per_block_pos_emb + _normalized_x = self.norm_state(previous_block_out) + normalized_x = _normalized_x * (1 + scale_L_B_D) + shift_L_B_D + + assert self.block.dropout.p == 0.0, "we skip dropout to save memory" + + return _fn2(self.block.layer1(normalized_x)), previous_block_out + + output, previous_block_out = torch.utils.checkpoint.checkpoint(_fn, x_before_gate, x_skip, use_reentrant=False) + + return ( + _gate_L_B_D, + output, + previous_block_out, + ) + + def forward_with_memory_save( + self, + x_before_gate: torch.Tensor, + x_skip: torch.Tensor, + gate_L_B_D: torch.Tensor, + emb_B_D: torch.Tensor, + crossattn_emb: torch.Tensor, + crossattn_mask: Optional[torch.Tensor] = None, + rope_emb_L_1_1_D: Optional[torch.Tensor] = None, + adaln_lora_B_3D: Optional[torch.Tensor] = None, + extra_per_block_pos_emb: Optional[torch.Tensor] = None, + ): + if isinstance(self.block, VideoAttn): + if self.block.attn.is_selfattn: + fn = self.forward_with_attn_memory_save + else: + fn = self.forward_with_x_attn_memory_save + else: + # fn = self.forward_with_ffn_memory_save + fn = self.forward_with_ffn_memory_save_upgrade + return fn( + x_before_gate, + x_skip, + gate_L_B_D, + emb_B_D, + crossattn_emb, + crossattn_mask, + rope_emb_L_1_1_D, + adaln_lora_B_3D, + extra_per_block_pos_emb, + ) + + def forward( + self, + x: torch.Tensor, + emb_B_D: torch.Tensor, + crossattn_emb: torch.Tensor, + crossattn_mask: Optional[torch.Tensor] = None, + rope_emb_L_1_1_D: Optional[torch.Tensor] = None, + adaln_lora_B_3D: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """ + Forward pass for dynamically configured blocks with adaptive normalization. + + Args: + x (Tensor): Input tensor of shape (B, T, H, W, D) or (T, H, W, B, D). + emb_B_D (Tensor): Embedding tensor for adaptive layer normalization modulation. + crossattn_emb (Tensor): Tensor for cross-attention blocks. + crossattn_mask (Optional[Tensor]): Optional mask for cross-attention. + rope_emb_L_1_1_D (Optional[Tensor]): Rotary positional embedding tensor of shape (L, 1, 1, D). L == THW for current video training. transformer_engine format + + Returns: + Tensor: The output tensor after processing through the configured block and adaptive normalization. + """ + if self.use_adaln_lora: + shift_B_D, scale_B_D, gate_B_D = (self.adaLN_modulation(emb_B_D) + adaln_lora_B_3D).chunk( + self.n_adaln_chunks, dim=1 + ) + else: + shift_B_D, scale_B_D, gate_B_D = self.adaLN_modulation(emb_B_D).chunk(self.n_adaln_chunks, dim=1) + + if self.x_format == "BTHWD": + shift_B_1_1_1_D, scale_B_1_1_1_D, gate_B_1_1_1_D = ( + shift_B_D.unsqueeze(1).unsqueeze(2).unsqueeze(3), + scale_B_D.unsqueeze(1).unsqueeze(2).unsqueeze(3), + gate_B_D.unsqueeze(1).unsqueeze(2).unsqueeze(3), + ) + if self.block_type in ["spatial_sa", "temporal_sa", "window_attn", "ssa", "tsa", "wa"]: + x = x + gate_B_1_1_1_D * self.block( + self.norm_state(x) * (1 + scale_B_1_1_1_D) + shift_B_1_1_1_D, + rope_emb_L_1_1_D=rope_emb_L_1_1_D, + ) + elif self.block_type in ["full_attn", "fa"]: + x = x + gate_B_1_1_1_D * self.block( + self.norm_state(x) * (1 + scale_B_1_1_1_D) + shift_B_1_1_1_D, + context=None, + rope_emb_L_1_1_D=rope_emb_L_1_1_D, + ) + elif self.block_type in ["cross_attn", "ca"]: + x = x + gate_B_1_1_1_D * self.block( + self.norm_state(x) * (1 + scale_B_1_1_1_D) + shift_B_1_1_1_D, + crossattn_emb, + crossattn_mask, + ) + elif self.block_type in ["mlp", "ff"]: + x = x + gate_B_1_1_1_D * self.block( + self.norm_state(x) * (1 + scale_B_1_1_1_D) + shift_B_1_1_1_D, + ) + else: + raise ValueError(f"Unknown block type: {self.block_type}") + elif self.x_format == "THWBD": + shift_1_1_1_B_D, scale_1_1_1_B_D, gate_1_1_1_B_D = ( + shift_B_D.unsqueeze(0).unsqueeze(0).unsqueeze(0), + scale_B_D.unsqueeze(0).unsqueeze(0).unsqueeze(0), + gate_B_D.unsqueeze(0).unsqueeze(0).unsqueeze(0), + ) + + if self.block_type in ["mlp", "ff"]: + x = x + gate_1_1_1_B_D * self.block( + torch.utils.checkpoint.checkpoint( + checkpoint_norm_state, self.norm_state, x, scale_1_1_1_B_D, shift_1_1_1_B_D, use_reentrant=False + ), + ) + elif self.block_type in ["full_attn", "fa"]: + x = x + gate_1_1_1_B_D * self.block( + torch.utils.checkpoint.checkpoint( + checkpoint_norm_state, self.norm_state, x, scale_1_1_1_B_D, shift_1_1_1_B_D, use_reentrant=False + ), + context=None, + rope_emb_L_1_1_D=rope_emb_L_1_1_D, + ) + elif self.block_type in ["cross_attn", "ca"]: + x = x + gate_1_1_1_B_D * self.block( + torch.utils.checkpoint.checkpoint( + checkpoint_norm_state, self.norm_state, x, scale_1_1_1_B_D, shift_1_1_1_B_D, use_reentrant=False + ), + context=crossattn_emb, + crossattn_mask=crossattn_mask, + rope_emb_L_1_1_D=rope_emb_L_1_1_D, + ) + else: + raise ValueError(f"Unknown block type: {self.block_type}") + else: + raise NotImplementedError(f"Unsupported x_format: {self.x_format}") + return x + + +class GeneralDITTransformerBlock(nn.Module): + """ + This class is a wrapper for a list of DITBuildingBlock. + It's not essential, refactor it if needed. + """ + + def __init__( + self, + x_dim: int, + context_dim: int, + num_heads: int, + block_config: str, + mlp_ratio: float = 4.0, + window_sizes: list = [], + spatial_attn_win_size: int = 1, + temporal_attn_win_size: int = 1, + use_checkpoint: bool = False, + x_format: str = "BTHWD", + use_adaln_lora: bool = False, + adaln_lora_dim: int = 256, + n_views: int = 1, + ): + super().__init__() + self.blocks = nn.ModuleList() + self.x_format = x_format + for block_type in block_config.split("-"): + self.blocks.append( + DITBuildingBlock( + block_type, + x_dim, + context_dim, + num_heads, + mlp_ratio, + window_sizes, + spatial_attn_win_size, + temporal_attn_win_size, + x_format=self.x_format, + use_adaln_lora=use_adaln_lora, + adaln_lora_dim=adaln_lora_dim, + n_views=n_views, + ) + ) + self.use_checkpoint = use_checkpoint + + def forward( + self, + x: torch.Tensor, + emb_B_D: torch.Tensor, + crossattn_emb: torch.Tensor, + crossattn_mask: Optional[torch.Tensor] = None, + rope_emb_L_1_1_D: Optional[torch.Tensor] = None, + adaln_lora_B_3D: Optional[torch.Tensor] = None, + extra_per_block_pos_emb: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + if self.use_checkpoint: + return torch.utils.checkpoint.checkpoint( + self._forward, + x, + emb_B_D, + crossattn_emb, + crossattn_mask, + rope_emb_L_1_1_D, + adaln_lora_B_3D, + extra_per_block_pos_emb, + ) + else: + return self._forward( + x, emb_B_D, crossattn_emb, crossattn_mask, rope_emb_L_1_1_D, adaln_lora_B_3D, extra_per_block_pos_emb + ) + + def _forward( + self, + x: torch.Tensor, + emb_B_D: torch.Tensor, + crossattn_emb: torch.Tensor, + crossattn_mask: Optional[torch.Tensor] = None, + rope_emb_L_1_1_D: Optional[torch.Tensor] = None, + adaln_lora_B_3D: Optional[torch.Tensor] = None, + extra_per_block_pos_emb: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + if extra_per_block_pos_emb is not None: + x = x + extra_per_block_pos_emb + for block in self.blocks: + x = block( + x, + emb_B_D, + crossattn_emb, + crossattn_mask, + rope_emb_L_1_1_D=rope_emb_L_1_1_D, + adaln_lora_B_3D=adaln_lora_B_3D, + ) + return x + + def set_memory_save(self, mode: bool = True): + # (qsh) to make fsdp happy! + #! IMPORTANT! + if mode: + self.forward = self.forward_with_memory_save + for block in self.blocks: + block.forward = block.forward_with_memory_save + else: + raise NotImplementedError("Not implemented yet.") + + def forward_with_memory_save( + self, + x_before_gate: torch.Tensor, + x_skip: torch.Tensor, + gate_L_B_D: torch.Tensor, + emb_B_D: torch.Tensor, + crossattn_emb: torch.Tensor, + crossattn_mask: Optional[torch.Tensor] = None, + rope_emb_L_1_1_D: Optional[torch.Tensor] = None, + adaln_lora_B_3D: Optional[torch.Tensor] = None, + extra_per_block_pos_emb: Optional[torch.Tensor] = None, + ): + for block in self.blocks: + gate_L_B_D, x_before_gate, x_skip = block.forward( + x_before_gate, + x_skip, + gate_L_B_D, + emb_B_D, + crossattn_emb, + crossattn_mask, + rope_emb_L_1_1_D, + adaln_lora_B_3D, + extra_per_block_pos_emb, + ) + extra_per_block_pos_emb = None + return gate_L_B_D, x_before_gate, x_skip \ No newline at end of file diff --git a/cosmos_transfer1/diffusion/training/modules/pretrained_vae.py b/cosmos_transfer1/diffusion/training/modules/pretrained_vae.py new file mode 100644 index 00000000..21769467 --- /dev/null +++ b/cosmos_transfer1/diffusion/training/modules/pretrained_vae.py @@ -0,0 +1,738 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +from abc import ABC, abstractmethod + +import torch +import torch.nn.functional as F +from torch.nn.modules import Module +from einops import rearrange + +from cosmos_transfer1.utils.distributed import rank0_first + + +class BaseVAE(torch.nn.Module, ABC): + """ + Abstract base class for a Variational Autoencoder (VAE). + + All subclasses should implement the methods to define the behavior for encoding + and decoding, along with specifying the latent channel size. + """ + + def __init__(self, channel: int = 3, name: str = "vae"): + super().__init__() + self.channel = channel + self.name = name + + @property + def latent_ch(self) -> int: + """ + Returns the number of latent channels in the VAE. + """ + return self.channel + + @abstractmethod + def encode(self, state: torch.Tensor) -> torch.Tensor: + """ + Encodes the input tensor into a latent representation. + + Args: + - state (torch.Tensor): The input tensor to encode. + + Returns: + - torch.Tensor: The encoded latent tensor. + """ + pass + + @abstractmethod + def decode(self, latent: torch.Tensor) -> torch.Tensor: + """ + Decodes the latent representation back to the original space. + + Args: + - latent (torch.Tensor): The latent tensor to decode. + + Returns: + - torch.Tensor: The decoded tensor. + """ + pass + + @property + def spatial_compression_factor(self) -> int: + """ + Returns the spatial reduction factor for the VAE. + """ + raise NotImplementedError("The spatial_compression_factor property must be implemented in the derived class.") + + + +class VideoTokenizerInterface(ABC): + @abstractmethod + def encode(self, state: torch.Tensor) -> torch.Tensor: + pass + + @abstractmethod + def decode(self, latent: torch.Tensor) -> torch.Tensor: + pass + + @abstractmethod + def get_latent_num_frames(self, num_pixel_frames: int) -> int: + pass + + @abstractmethod + def get_pixel_num_frames(self, num_latent_frames: int) -> int: + pass + + @property + @abstractmethod + def spatial_compression_factor(self): + pass + + @property + @abstractmethod + def temporal_compression_factor(self): + pass + + @property + @abstractmethod + def spatial_resolution(self): + pass + + @property + @abstractmethod + def pixel_chunk_duration(self): + pass + + @property + @abstractmethod + def latent_chunk_duration(self): + pass + + @property + def is_chunk_overlap(self): + return False + + +class BasePretrainedImageVAE(BaseVAE): + """ + A base class for pretrained Variational Autoencoder (VAE) that loads mean and standard deviation values + from a remote store, handles data type conversions, and normalization + using provided mean and standard deviation values for latent space representation. + Derived classes should load pre-trained encoder and decoder components from a remote store + + Attributes: + latent_mean (Tensor): The mean used for normalizing the latent representation. + latent_std (Tensor): The standard deviation used for normalizing the latent representation. + dtype (dtype): Data type for model tensors, determined by whether bf16 is enabled. + + Args: + mean_std_fp (str): File path to the pickle file containing mean and std of the latent space. + latent_ch (int, optional): Number of latent channels (default is 16). + is_image (bool, optional): Flag to indicate whether the output is an image (default is True). + is_bf16 (bool, optional): Flag to use Brain Floating Point 16-bit data type (default is True). + """ + + def __init__( + self, + name: str, + mean_std_fp: str, + latent_ch: int = 16, + is_image: bool = True, + is_bf16: bool = True, + ) -> None: + super().__init__(latent_ch, name) + dtype = torch.bfloat16 if is_bf16 else torch.float32 + self.dtype = dtype + self.is_image = is_image + self.mean_std_fp = mean_std_fp + self.name = name + + self.backend_args = None + + self.register_mean_std(mean_std_fp) + + def register_mean_std(self, mean_std_fp: str) -> None: + latent_mean, latent_std = torch.load(mean_std_fp, map_location="cuda", weights_only=True) + target_shape = [1, self.latent_ch, 1, 1] if self.is_image else [1, self.latent_ch, 1, 1, 1] + self.register_buffer( + "latent_mean", + latent_mean.to(self.dtype).reshape(*target_shape), + persistent=False, + ) + self.register_buffer( + "latent_std", + latent_std.to(self.dtype).reshape(*target_shape), + persistent=False, + ) + + @torch.no_grad() + def encode(self, state: torch.Tensor) -> torch.Tensor: + """ + Encode the input state to latent space; also handle the dtype conversion, mean and std scaling + """ + in_dtype = state.dtype + latent_mean = self.latent_mean.to(in_dtype) + latent_std = self.latent_std.to(in_dtype) + encoded_state = self.encoder(state.to(self.dtype)) + if isinstance(encoded_state, torch.Tensor): + pass + elif isinstance(encoded_state, tuple): + assert isinstance(encoded_state[0], torch.Tensor) + encoded_state = encoded_state[0] + else: + raise ValueError("Invalid type of encoded state") + return (encoded_state.to(in_dtype) - latent_mean) / latent_std + + @torch.no_grad() + def decode(self, latent: torch.Tensor) -> torch.Tensor: + """ + Decode the input latent to state; also handle the dtype conversion, mean and std scaling + """ + in_dtype = latent.dtype + latent = latent * self.latent_std.to(in_dtype) + self.latent_mean.to(in_dtype) + return self.decoder(latent.to(self.dtype)).to(in_dtype) + + def reset_dtype(self, *args, **kwargs): + """ + Resets the data type of the encoder and decoder to the model's default data type. + + Args: + *args, **kwargs: Unused, present to allow flexibility in method calls. + """ + del args, kwargs + self.decoder.to(self.dtype) + self.encoder.to(self.dtype) + + +class JITVAE(BasePretrainedImageVAE): + """ + A JIT compiled Variational Autoencoder (VAE) that loads pre-trained encoder + and decoder components from a remote store, handles data type conversions, and normalization + using provided mean and standard deviation values for latent space representation. + + Attributes: + encoder (Module): The JIT compiled encoder loaded from storage. + decoder (Module): The JIT compiled decoder loaded from storage. + latent_mean (Tensor): The mean used for normalizing the latent representation. + latent_std (Tensor): The standard deviation used for normalizing the latent representation. + dtype (dtype): Data type for model tensors, determined by whether bf16 is enabled. + + Args: + enc_fp (str): File path to the encoder's JIT file on the remote store. + dec_fp (str): File path to the decoder's JIT file on the remote store. + name (str): Name of the model, used for differentiating cache file paths. + mean_std_fp (str): File path to the pickle file containing mean and std of the latent space. + latent_ch (int, optional): Number of latent channels (default is 16). + is_image (bool, optional): Flag to indicate whether the output is an image (default is True). + is_bf16 (bool, optional): Flag to use Brain Floating Point 16-bit data type (default is True). + """ + + def __init__( + self, + enc_fp: str, + dec_fp: str, + name: str, + mean_std_fp: str, + latent_ch: int = 16, + is_image: bool = True, + is_bf16: bool = True, + ): + super().__init__(name, mean_std_fp, latent_ch, is_image, is_bf16) + self.load_encoder(enc_fp) + self.load_decoder(dec_fp) + + def load_encoder(self, enc_fp: str) -> None: + """ + Load the encoder from the remote store. + + Args: + - enc_fp (str): File path to the encoder's JIT file on the remote store. + """ + self.encoder = torch.jit.load(enc_fp, map_location="cuda") + self.encoder.eval() + for param in self.encoder.parameters(): + param.requires_grad = False + self.encoder.to(self.dtype) + + def load_decoder(self, dec_fp: str) -> None: + """ + Load the decoder from the remote store. + + Args: + - dec_fp (str): File path to the decoder's JIT file on the remote store. + """ + self.decoder = torch.jit.load(dec_fp, map_location="cuda") + self.decoder.eval() + for param in self.decoder.parameters(): + param.requires_grad = False + self.decoder.to(self.dtype) + + +# class StateDictVAE(BasePretrainedImageVAE): +# """ +# A Variational Autoencoder (VAE) that loads pre-trained weights into +# provided encoder and decoder components from a remote store, handles data type conversions, +# and normalization using provided mean and standard deviation values for latent space representation. + +# Attributes: +# encoder (Module): The encoder with weights loaded from storage. +# decoder (Module): The decoder with weights loaded from storage. +# latent_mean (Tensor): The mean used for normalizing the latent representation. +# latent_std (Tensor): The standard deviation used for normalizing the latent representation. +# dtype (dtype): Data type for model tensors, determined by whether bf16 is enabled. + +# Args: +# enc_fp (str): File path to the encoder's JIT file on the remote store. +# dec_fp (str): File path to the decoder's JIT file on the remote store. +# vae (Module): Instance of VAE with not loaded weights +# name (str): Name of the model, used for differentiating cache file paths. +# mean_std_fp (str): File path to the pickle file containing mean and std of the latent space. +# latent_ch (int, optional): Number of latent channels (default is 16). +# is_image (bool, optional): Flag to indicate whether the output is an image (default is True). +# is_bf16 (bool, optional): Flag to use Brain Floating Point 16-bit data type (default is True). +# """ + +# def __init__( +# self, +# enc_fp: str, +# dec_fp: str, +# vae: torch.nn.Module, +# name: str, +# mean_std_fp: str, +# latent_ch: int = 16, +# is_image: bool = True, +# is_bf16: bool = True, +# ): +# super().__init__(name, mean_std_fp, latent_ch, is_image, is_bf16) + +# self.load_encoder_and_decoder(enc_fp, dec_fp, vae) + +# def load_encoder_and_decoder(self, enc_fp: str, dec_fp: str, vae: torch.nn.Module) -> None: +# """ +# Load the encoder from the remote store. + +# Args: +# - vae_fp (str): File path to the vae's state dict file on the remote store. +# - vae (str): VAE module into which weights will be loaded. +# """ +# state_dict_enc = load_from_s3_with_cache( +# enc_fp, +# f"vae/{self.name}_enc.jit", +# easy_io_kwargs={"map_location": torch.device(torch.cuda.current_device())}, +# backend_args=self.backend_args, +# ) + +# state_dict_dec = load_from_s3_with_cache( +# dec_fp, +# f"vae/{self.name}_dec.jit", +# easy_io_kwargs={"map_location": torch.device(torch.cuda.current_device())}, +# backend_args=self.backend_args, +# ) + +# jit_weights_state_dict = state_dict_enc.state_dict() | state_dict_dec.state_dict() +# jit_weights_state_dict = { +# k: v +# for k, v in jit_weights_state_dict.items() +# # Global variables captured by JIT +# if k +# not in ( +# "encoder.patcher.wavelets", +# "encoder.patcher._arange", +# "decoder.unpatcher.wavelets", +# "decoder.unpatcher._arange", +# ) +# } + +# vae.load_state_dict(jit_weights_state_dict) +# vae.eval() +# for param in vae.parameters(): +# param.requires_grad = False +# vae.to(self.dtype) + +# self.vae = vae +# self.encoder = self.vae.encode +# self.decoder = self.vae.decode + +# def reset_dtype(self, *args, **kwargs): +# """ +# Resets the data type of the encoder and decoder to the model's default data type. + +# Args: +# *args, **kwargs: Unused, present to allow flexibility in method calls. +# """ +# del args, kwargs +# self.vae.to(self.dtype) + + +class SDVAE(BaseVAE): + def __init__(self, batch_size=16, count_std: bool = False, is_downsample: bool = True) -> None: + super().__init__(channel=4, name="sd_vae") + self.dtype = torch.bfloat16 + self.register_buffer( + "scale", + torch.tensor([4.17, 4.62, 3.71, 3.28], dtype=self.dtype).reciprocal().reshape(1, -1, 1, 1), + persistent=False, + ) + self.register_buffer( + "bias", + -1.0 * torch.tensor([5.81, 3.25, 0.12, -2.15], dtype=self.dtype).reshape(1, -1, 1, 1) * self.scale, + persistent=False, + ) + self.batch_size = batch_size + self.count_std = count_std + self.is_downsample = is_downsample + self.load_vae() + self.reset_dtype() + + def reset_dtype(self, *args, **kwargs): + del args, kwargs + self.vae.to(self.dtype) + + @rank0_first + def load_vae(self) -> None: + os.environ["HF_HUB_DISABLE_SYMLINKS_WARNING"] = "1" + os.environ["HF_HUB_DISABLE_PROGRESS_BARS"] = "1" + import diffusers + + vae_name = "stabilityai/sd-vae-ft-mse" + try: + vae = diffusers.models.AutoencoderKL.from_pretrained(vae_name, local_files_only=True) + except: # noqa: E722 + # Could not load the model from cache; try without local_files_only. + vae = diffusers.models.AutoencoderKL.from_pretrained(vae_name) + self.vae = vae.eval().requires_grad_(False) + + @torch.no_grad() + def encode(self, state: torch.Tensor) -> torch.Tensor: + """ + state : pixel range [-1, 1] + """ + if self.is_downsample: + _h, _w = state.shape[-2:] + state = F.interpolate(state, size=(_h // 2, _w // 2), mode="bilinear", align_corners=False) + in_dtype = state.dtype + state = state.to(self.dtype) + state = (state + 1.0) / 2.0 + latent_dist = self.vae.encode(state)["latent_dist"] + mean, std = latent_dist.mean, latent_dist.std + if self.count_std: + latent = mean + torch.randn_like(mean) * std + else: + latent = mean + latent = latent * self.scale + latent = latent + self.bias + return latent.to(in_dtype) + + @torch.no_grad() + def decode(self, latent: torch.Tensor) -> torch.Tensor: + in_dtype = latent.dtype + latent = latent.to(self.dtype) + latent = latent - self.bias + latent = latent / self.scale + latent = torch.cat([self.vae.decode(batch)["sample"] for batch in latent.split(self.batch_size)]) + if self.is_downsample: + _h, _w = latent.shape[-2:] + latent = F.interpolate(latent, size=(_h * 2, _w * 2), mode="bilinear", align_corners=False) + return latent.to(in_dtype) * 2 - 1.0 + + @property + def spatial_compression_factor(self) -> int: + return 8 + + +class BasePretrainedVideoTokenizer(ABC): + """ + Base class for a pretrained video tokenizer that handles chunking of video data for efficient processing. + + Args: + pixel_chunk_duration (int): The duration (in number of frames) of each chunk of video data at the pixel level. + temporal_compress_factor (int): The factor by which the video data is temporally compressed during processing. + max_enc_batch_size (int): The maximum batch size to process in one go during encoding to avoid memory overflow. + max_dec_batch_size (int): The maximum batch size to process in one go during decoding to avoid memory overflow. + + The class introduces parameters for managing temporal chunks (`pixel_chunk_duration` and `temporal_compress_factor`) + which define how video data is subdivided and compressed during the encoding and decoding processes. The + `max_enc_batch_size` and `max_dec_batch_size` parameters allow processing in smaller batches to handle memory + constraints. + """ + + def __init__( + self, + pixel_chunk_duration: int = 17, + temporal_compress_factor: int = 8, + max_enc_batch_size: int = 8, + max_dec_batch_size: int = 4, + ): + self._pixel_chunk_duration = pixel_chunk_duration + self._temporal_compress_factor = temporal_compress_factor + self.max_enc_batch_size = max_enc_batch_size + self.max_dec_batch_size = max_dec_batch_size + + def register_mean_std(self, mean_std_fp: str) -> None: + latent_mean, latent_std = torch.load(mean_std_fp, map_location="cuda", weights_only=True) + latent_mean = latent_mean.view(self.latent_ch, -1)[:, : self.latent_chunk_duration] + latent_std = latent_std.view(self.latent_ch, -1)[:, : self.latent_chunk_duration] + + target_shape = [1, self.latent_ch, self.latent_chunk_duration, 1, 1] + + self.register_buffer( + "latent_mean", + latent_mean.to(self.dtype).reshape(*target_shape), + persistent=False, + ) + self.register_buffer( + "latent_std", + latent_std.to(self.dtype).reshape(*target_shape), + persistent=False, + ) + + def transform_encode_state_shape(self, state: torch.Tensor) -> torch.Tensor: + """ + Rearranges the input state tensor to the required shape for encoding video data. Mainly for chunk based encoding + """ + B, C, T, H, W = state.shape + assert ( + T % self.pixel_chunk_duration == 0 + ), f"Temporal dimension {T} is not divisible by chunk_length {self.pixel_chunk_duration}" + return rearrange(state, "b c (n t) h w -> (b n) c t h w", t=self.pixel_chunk_duration) + + def transform_decode_state_shape(self, latent: torch.Tensor) -> None: + B, _, T, _, _ = latent.shape + assert ( + T % self.latent_chunk_duration == 0 + ), f"Temporal dimension {T} is not divisible by chunk_length {self.latent_chunk_duration}" + return rearrange(latent, "b c (n t) h w -> (b n) c t h w", t=self.latent_chunk_duration) + + @torch.no_grad() + def encode(self, state: torch.Tensor) -> torch.Tensor: + if self._temporal_compress_factor == 1: + _, _, origin_T, _, _ = state.shape + state = rearrange(state, "b c t h w -> (b t) c 1 h w") + B, C, T, H, W = state.shape + state = self.transform_encode_state_shape(state) + # use max_enc_batch_size to avoid OOM + if state.shape[0] > self.max_enc_batch_size: + latent = [] + for i in range(0, state.shape[0], self.max_enc_batch_size): + latent.append(super().encode(state[i : i + self.max_enc_batch_size])) + latent = torch.cat(latent, dim=0) + else: + latent = super().encode(state) + + latent = rearrange(latent, "(b n) c t h w -> b c (n t) h w", b=B) + if self._temporal_compress_factor == 1: + latent = rearrange(latent, "(b t) c 1 h w -> b c t h w", t=origin_T) + return latent + + @torch.no_grad() + def decode(self, latent: torch.Tensor) -> torch.Tensor: + """ + Decodes a batch of latent representations into video frames by applying temporal chunking. Similar to encode, + it handles video data by processing smaller temporal chunks to reconstruct the original video dimensions. + + It can also decode single frame image data. + + Args: + latent (torch.Tensor): The latent space tensor containing encoded video data. + + Returns: + torch.Tensor: The decoded video tensor reconstructed from latent space. + """ + if self._temporal_compress_factor == 1: + _, _, origin_T, _, _ = latent.shape + latent = rearrange(latent, "b c t h w -> (b t) c 1 h w") + B, _, T, _, _ = latent.shape + latent = self.transform_decode_state_shape(latent) + # use max_enc_batch_size to avoid OOM + if latent.shape[0] > self.max_dec_batch_size: + state = [] + for i in range(0, latent.shape[0], self.max_dec_batch_size): + state.append(super().decode(latent[i : i + self.max_dec_batch_size])) + state = torch.cat(state, dim=0) + else: + state = super().decode(latent) + assert state.shape[2] == self.pixel_chunk_duration + state = rearrange(state, "(b n) c t h w -> b c (n t) h w", b=B) + if self._temporal_compress_factor == 1: + return rearrange(state, "(b t) c 1 h w -> b c t h w", t=origin_T) + return state + + @property + def pixel_chunk_duration(self) -> int: + return self._pixel_chunk_duration + + @property + def latent_chunk_duration(self) -> int: + # return self._latent_chunk_duration + assert (self.pixel_chunk_duration - 1) % self.temporal_compression_factor == 0, ( + f"Pixel chunk duration {self.pixel_chunk_duration} is not divisible by latent chunk duration " + f"{self.latent_chunk_duration}" + ) + return (self.pixel_chunk_duration - 1) // self.temporal_compression_factor + 1 + + @property + def temporal_compression_factor(self): + return self._temporal_compress_factor + + def get_latent_num_frames(self, num_pixel_frames: int) -> int: + if num_pixel_frames == 1: + return 1 + assert ( + num_pixel_frames % self.pixel_chunk_duration == 0 + ), f"Temporal dimension {num_pixel_frames} is not divisible by chunk_length {self.pixel_chunk_duration}" + return num_pixel_frames // self.pixel_chunk_duration * self.latent_chunk_duration + + def get_pixel_num_frames(self, num_latent_frames: int) -> int: + if num_latent_frames == 1: + return 1 + assert ( + num_latent_frames % self.latent_chunk_duration == 0 + ), f"Temporal dimension {num_latent_frames} is not divisible by chunk_length {self.latent_chunk_duration}" + return num_latent_frames // self.latent_chunk_duration * self.pixel_chunk_duration + + +class VideoJITTokenizer(BasePretrainedVideoTokenizer, JITVAE, VideoTokenizerInterface): + """ + Instance of BasePretrainedVideoVAE that loads encoder and decoder from JIT scripted module file + """ + + def __init__( + self, + enc_fp: str, + dec_fp: str, + name: str, + mean_std_fp: str, + latent_ch: int = 16, + is_bf16: bool = True, + spatial_compression_factor: int = 16, + temporal_compression_factor: int = 8, + pixel_chunk_duration: int = 17, + max_enc_batch_size: int = 8, + max_dec_batch_size: int = 4, + spatial_resolution: str = "720", + ): + super().__init__(pixel_chunk_duration, temporal_compression_factor, max_enc_batch_size, max_dec_batch_size) + super(BasePretrainedVideoTokenizer, self).__init__(enc_fp, dec_fp, name, mean_std_fp, latent_ch, False, is_bf16) + + self._spatial_compression_factor = spatial_compression_factor + self._spatial_resolution = spatial_resolution + + @property + def spatial_compression_factor(self): + return self._spatial_compression_factor + + @property + def spatial_resolution(self) -> str: + return self._spatial_resolution + + + +class JointImageVideoTokenizer(BaseVAE, VideoTokenizerInterface): + def __init__( + self, + image_vae: torch.nn.Module, + video_vae: torch.nn.Module, + name: str, + latent_ch: int = 16, + squeeze_for_image: bool = True, + ): + super().__init__(latent_ch, name) + self.image_vae = image_vae + self.video_vae = video_vae + self.squeeze_for_image = squeeze_for_image + + def encode_image(self, state: torch.Tensor) -> torch.Tensor: + if self.squeeze_for_image: + return self.image_vae.encode(state.squeeze(2)).unsqueeze(2) + return self.image_vae.encode(state) + + def decode_image(self, latent: torch.Tensor) -> torch.Tensor: + if self.squeeze_for_image: + return self.image_vae.decode(latent.squeeze(2)).unsqueeze(2) + return self.image_vae.decode(latent) + + @torch.no_grad() + def encode(self, state: torch.Tensor) -> torch.Tensor: + B, C, T, H, W = state.shape + if T == 1: + return self.encode_image(state) + + return self.video_vae.encode(state) + + @torch.no_grad() + def decode(self, latent: torch.Tensor) -> torch.Tensor: + B, C, T, H, W = latent.shape + if T == 1: + return self.decode_image(latent) + return self.video_vae.decode(latent) + + def reset_dtype(self, *args, **kwargs): + """ + Resets the data type of the encoder and decoder to the model's default data type. + + Args: + *args, **kwargs: Unused, present to allow flexibility in method calls. + """ + del args, kwargs + self.image_vae.reset_dtype() + self.video_vae.reset_dtype() + + def get_latent_num_frames(self, num_pixel_frames: int) -> int: + if num_pixel_frames == 1: + return 1 + return self.video_vae.get_latent_num_frames(num_pixel_frames) + + def get_pixel_num_frames(self, num_latent_frames: int) -> int: + if num_latent_frames == 1: + return 1 + return self.video_vae.get_pixel_num_frames(num_latent_frames) + + @property + def spatial_compression_factor(self): + return self.video_vae.spatial_compression_factor + + @property + def temporal_compression_factor(self): + return self.video_vae.temporal_compression_factor + + @property + def spatial_resolution(self) -> str: + return self.video_vae.spatial_resolution + + @property + def pixel_chunk_duration(self) -> int: + return self.video_vae.pixel_chunk_duration + + @property + def latent_chunk_duration(self) -> int: + return self.video_vae.latent_chunk_duration + + +class JointImageVideoSharedJITTokenizer(JointImageVideoTokenizer): + """ + First version of the ImageVideoVAE trained with Fitsum. + We have to use seperate mean and std for image and video due to non-causal nature of the model. + """ + + def __init__(self, image_vae: Module, video_vae: Module, name: str, latent_ch: int = 16): + super().__init__(image_vae, video_vae, name, latent_ch, squeeze_for_image=False) + assert isinstance(image_vae, JITVAE) + assert isinstance( + video_vae, VideoJITTokenizer + ), f"video_vae should be an instance of VideoJITVAE, got {type(video_vae)}" + # a hack to make the image_vae and video_vae share the same encoder and decoder + self.image_vae.encoder = self.video_vae.encoder + self.image_vae.decoder = self.video_vae.decoder diff --git a/cosmos_transfer1/diffusion/training/networks/__init__.py b/cosmos_transfer1/diffusion/training/networks/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/cosmos_transfer1/diffusion/training/networks/general_dit.py b/cosmos_transfer1/diffusion/training/networks/general_dit.py new file mode 100644 index 00000000..8b7061bf --- /dev/null +++ b/cosmos_transfer1/diffusion/training/networks/general_dit.py @@ -0,0 +1,1029 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +A general implementation of adaln-modulated VIT-like~(DiT) transformer for video processing. +It allows us easy to switch building blocks used and their order. Its instantiation includes +* transformer on fully flattened tokens +* factored spatial and temporal attention +* factored non-overlap spatial and temporal attention +* mixing of above attention types + +Limitations: + +* In favor of simplicity and cleanness, many ops are not fused and we can do better +* such as combining mutiple adaln MLPs into one inside one transformer block. +* we use reshape heavily, which may be not efficient when its occurs unnecessary CUDA memory copy + +Purpose: +* A prototype for testing different attention types and their combinations +* Idealy, we want to know where we should allocate our resources / FLOPS / memory via extensive empirical studies +""" + + +from collections.abc import Container +from typing import List, Optional, Tuple + +import torch +from einops import rearrange +from megatron.core import parallel_state +from torch import nn +from torch.distributed import ProcessGroup, get_process_group_ranks +from torchvision import transforms + +from cosmos_transfer1.diffusion.conditioner import DataType +from cosmos_transfer1.diffusion.module.attention import get_normalization +from cosmos_transfer1.diffusion.training.modules.blocks import ( + DITBuildingBlock, + FinalLayer, + GeneralDITTransformerBlock, + PatchEmbed, + SDXLTimestepEmbedding, + SDXLTimesteps, +) +from cosmos_transfer1.diffusion.module.position_embedding import ( + LearnableEmb3D, + LearnableEmb3D_FPS_Aware, + LearnablePosEmbAxis, + SinCosPosEmb, + SinCosPosEmb_FPS_Aware, + SinCosPosEmbAxis, + VideoRopePosition3DEmb, + VideoRopePositionEmb, +) +from cosmos_transfer1.diffusion.training.tensor_parallel import gather_along_first_dim, scatter_along_first_dim +from cosmos_transfer1.utils import log + + +class GeneralDIT(nn.Module): + """ + A general implementation of adaln-modulated VIT-like~(DiT) transformer for video processing. + Attributes: + max_img_h (int): Maximum height of the input images. + max_img_w (int): Maximum width of the input images. + max_frames (int): Maximum number of frames in the video sequence. + in_channels (int): Number of input channels (e.g., RGB channels for color images). + out_channels (int): Number of output channels. + patch_spatial (tuple of int): Spatial resolution of patches for input processing. + patch_temporal (int): Temporal resolution of patches for input processing. + concat_padding_mask (bool): If True, includes a mask channel in the input to handle padding. + block_config (str): Configuration of the transformer block, e.g., 'FA-CA-MLP', means + full attention, cross attention, and MLP in sequence in one transformer block. + model_channels (int): Base number of channels used throughout the model. + num_blocks (int): Number of residual blocks per resolution in the transformer. + num_heads (int): Number of heads in the multi-head self-attention layers. + spatial_attn_win_size (int): Window size for the spatial attention mechanism. + temporal_attn_win_size (int): Window size for the temporal attention mechanism. + mlp_ratio (float): Expansion ratio for the MLP (multi-layer perceptron) blocks in the transformer. + use_memory_save (bool): If True, utilizes checkpointing to reduce memory usage during training. (Deprecated) + use_checkpoint (bool): If True, utilizes checkpointing to reduce memory usage during training for all blocks. + crossattn_emb_channels (int): Number of embedding channels used in the cross-attention layers. + use_cross_attn_mask (bool): If True, applies a mask during cross-attention operations to manage sequence alignment. + pos_emb_cls (str): Type of positional embeddings used ('sincos' for sinusoidal or other types). + pos_emb_learnable (bool): Specifies if positional embeddings are learnable. + pos_emb_interpolation (str): Method used for interpolating positional embeddings, e.g., 'crop' for cropping adjustments. + block_x_format (str, optional): The format of the input tensor for the transformer block. Defaults to "BTHWD". Only support 'BTHWD' and 'THWBD'. + legacy_patch_emb (bool): If True, applies 3D convolutional layers for video inputs, otherwise, use Linear! This is for backward compatibility. + rope_h_extrapolation_ratio (float): Ratio of the height extrapolation for the rope positional embedding. + rope_w_extrapolation_ratio (float): Ratio of the width extrapolation for the rope positional embedding. + rope_t_extrapolation_ratio (float): Ratio of the temporal extrapolation for the rope positional embedding. + Note: + block_config support block type: + * spatial_sa, ssa: spatial self attention + * temporal_sa, tsa: temporal self attention + * cross_attn, ca: cross attention + * full_attn: full attention on all flatten tokens + * mlp, ff: feed forward block + * use '-' to separate different building blocks, e.g., 'FA-CA-MLP' means full attention, cross attention, and MLP in sequence in one transformer block. + + Example: + >>> # full attention, cross attention, and MLP + >>> option1_block_config = 'FA-CA-MLP' + >>> model_1 = GeneralDIT( + max_img_h=64, max_img_w=64, max_frames=32, in_channels=16, out_channels=16, + patch_spatial=2, patch_temporal=1, model_channels=768, num_blocks=10, + num_heads=16, mlp_ratio=4.0, + spatial_attn_win_size=1, temporal_attn_win_size=1, + block_config=option1_block_config + ) + >>> option2_block_config = 'SSA-CA-MLP-TSA-CA-MLP' + >>> model_2 = GeneralDIT( + max_img_h=64, max_img_w=64, max_frames=32, in_channels=16, out_channels=16, + patch_spatial=2, patch_temporal=1, model_channels=768, num_blocks=10, + num_heads=16, mlp_ratio=4.0, + spatial_attn_win_size=1, temporal_attn_win_size=1, + block_config=option2_block_config + ) + >>> # option3 model + >>> model_3 = GeneralDIT( + max_img_h=64, max_img_w=64, max_frames=32, in_channels=16, out_channels=16, + patch_spatial=2, patch_temporal=1, model_channels=768, num_blocks=10, + num_heads=16, mlp_ratio=4.0, + spatial_attn_win_size=1, temporal_attn_win_size=2, + block_config=option2_block_config + ) + >>> # Process input tensor through the model + >>> output = model(input_tensor) + """ + + def __init__( + self, + max_img_h: int, + max_img_w: int, + max_frames: int, + in_channels: int, + out_channels: int, + patch_spatial: tuple, + patch_temporal: int, + concat_padding_mask: bool = True, + # attention settings + block_config: str = "FA-CA-MLP", + model_channels: int = 768, + num_blocks: int = 10, + num_heads: int = 16, + window_block_indexes: list = [], # index for window attention block + window_sizes: list = [], # window size for window attention block in the order of T, H, W + spatial_attn_win_size: int = 1, + temporal_attn_win_size: int = 1, + mlp_ratio: float = 4.0, + use_memory_save: bool = False, + use_checkpoint: bool = False, + block_x_format: str = "BTHWD", + # cross attention settings + crossattn_emb_channels: int = 1024, + use_cross_attn_mask: bool = False, + # positional embedding settings + pos_emb_cls: str = "sincos", + pos_emb_learnable: bool = False, + pos_emb_interpolation: str = "crop", + min_fps: int = 1, # 1 for getty video + max_fps: int = 30, # 120 for getty video but let's use 30 + additional_timestamp_channels: dict = None, # Follow SDXL, in format of {condition_name : dimension} + affline_emb_norm: bool = False, # whether or not to normalize the affine embedding + use_adaln_lora: bool = False, + adaln_lora_dim: int = 256, + layer_mask: list = None, # whether or not a layer is used. For controlnet encoder + legacy_patch_emb: bool = True, + rope_h_extrapolation_ratio: float = 1.0, + rope_w_extrapolation_ratio: float = 1.0, + rope_t_extrapolation_ratio: float = 1.0, + extra_per_block_abs_pos_emb: bool = False, + extra_per_block_abs_pos_emb_type: str = "sincos", + extra_h_extrapolation_ratio: float = 1.0, + extra_w_extrapolation_ratio: float = 1.0, + extra_t_extrapolation_ratio: float = 1.0, + ) -> None: + super().__init__() + self.max_img_h = max_img_h + self.max_img_w = max_img_w + self.max_frames = max_frames + self.in_channels = in_channels + self.out_channels = out_channels + self.patch_spatial = patch_spatial + self.patch_temporal = patch_temporal + self.num_heads = num_heads + self.num_blocks = num_blocks + self.model_channels = model_channels + self.use_cross_attn_mask = use_cross_attn_mask + self.concat_padding_mask = concat_padding_mask + # positional embedding settings + self.pos_emb_cls = pos_emb_cls + self.pos_emb_learnable = pos_emb_learnable + self.pos_emb_interpolation = pos_emb_interpolation + self.min_fps = min_fps + self.max_fps = max_fps + self.additional_timestamp_channels = additional_timestamp_channels + self.affline_emb_norm = affline_emb_norm + self.legacy_patch_emb = legacy_patch_emb + self.rope_h_extrapolation_ratio = rope_h_extrapolation_ratio + self.rope_w_extrapolation_ratio = rope_w_extrapolation_ratio + self.rope_t_extrapolation_ratio = rope_t_extrapolation_ratio + self.extra_per_block_abs_pos_emb = extra_per_block_abs_pos_emb + self.extra_per_block_abs_pos_emb_type = extra_per_block_abs_pos_emb_type.lower() + self.extra_h_extrapolation_ratio = extra_h_extrapolation_ratio + self.extra_w_extrapolation_ratio = extra_w_extrapolation_ratio + self.extra_t_extrapolation_ratio = extra_t_extrapolation_ratio + + self.build_patch_embed() + self.build_pos_embed() + self.cp_group = None + self.sequence_parallel = getattr(parallel_state, "sequence_parallel", False) + self.block_x_format = block_x_format + self.use_adaln_lora = use_adaln_lora + self.adaln_lora_dim = adaln_lora_dim + self.t_embedder = nn.Sequential( + SDXLTimesteps(model_channels), + SDXLTimestepEmbedding(model_channels, model_channels, use_adaln_lora=use_adaln_lora), + ) + + self.blocks = nn.ModuleDict() + self.block_config = block_config + self.use_memory_save = use_memory_save + self.use_checkpoint = use_checkpoint + + assert ( + len(window_block_indexes) == 0 or block_config == "FA-CA-MLP" + ), "Block config must be FA-CA-MLP if using a combination of window attention and global attention" + + layer_mask = [False] * num_blocks if layer_mask is None else layer_mask + assert ( + len(layer_mask) == num_blocks + ), f"Layer mask length {len(layer_mask)} does not match num_blocks {num_blocks}" + for idx in range(num_blocks): + if layer_mask[idx]: + continue + self.blocks[f"block{idx}"] = GeneralDITTransformerBlock( + x_dim=model_channels, + context_dim=crossattn_emb_channels, + num_heads=num_heads, + block_config=block_config, + window_sizes=( + window_sizes if idx in window_block_indexes else [] + ), # There will be bug if using "WA-CA-MLP" + mlp_ratio=mlp_ratio, + spatial_attn_win_size=spatial_attn_win_size, + temporal_attn_win_size=temporal_attn_win_size, + x_format=self.block_x_format, + use_adaln_lora=use_adaln_lora, + adaln_lora_dim=adaln_lora_dim, + use_checkpoint=use_checkpoint, + ) + + self.build_decode_head() + self.build_additional_timestamp_embedder() + if self.affline_emb_norm: + log.critical("Building affine embedding normalization layer") + self.affline_norm = get_normalization("R", model_channels) + else: + self.affline_norm = nn.Identity() + self.init_weights() + + if self.use_memory_save: + log.critical("Using checkpointing to save memory! only verified in 14B base model training!") + for block in self.blocks.values(): + block.set_memory_save() + + def init_weights(self): + # Initialize transformer layers: + def _basic_init(module): + if isinstance(module, nn.Linear): + torch.nn.init.xavier_uniform_(module.weight) + if module.bias is not None: + nn.init.constant_(module.bias, 0) + + self.apply(_basic_init) + + # Initialize timestep embedding + nn.init.normal_(self.t_embedder[1].linear_1.weight, std=0.02) + if self.t_embedder[1].linear_1.bias is not None: + nn.init.constant_(self.t_embedder[1].linear_1.bias, 0) + nn.init.normal_(self.t_embedder[1].linear_2.weight, std=0.02) + if self.t_embedder[1].linear_2.bias is not None: + nn.init.constant_(self.t_embedder[1].linear_2.bias, 0) + + # Zero-out adaLN modulation layers in DiT blocks: + for transformer_block in self.blocks.values(): + for block in transformer_block.blocks: + nn.init.constant_(block.adaLN_modulation[-1].weight, 0) + if block.adaLN_modulation[-1].bias is not None: + nn.init.constant_(block.adaLN_modulation[-1].bias, 0) + + # Tensor parallel + if parallel_state.is_initialized() and parallel_state.get_tensor_model_parallel_world_size() > 1: + self.initialize_tensor_parallel_weights() + + def initialize_tensor_parallel_weights(self): + """ + Initialize weights for tensor parallel layers. + + This function performs the following steps: + 1. Retrieves the tensor parallel rank. + 2. Saves the current random state. + 3. Sets a new random seed based on the tensor parallel rank. + 4. Initializes weights for attention and MLP layers in each block. + 5. Restores the original random state. + + The use of different random seeds for each rank ensures + unique initializations across parallel processes. + """ + tp_rank = parallel_state.get_tensor_model_parallel_rank() + + # Save the current random state + rng_state = torch.get_rng_state() + + # Set a new random seed based on the tensor parallel rank + torch.manual_seed(tp_rank) + + for block in self.blocks.values(): + for layer in block.blocks: + if layer.block_type in ["full_attn", "fa", "cross_attn", "ca"]: + # Initialize weights for attention layers + torch.nn.init.xavier_uniform_(layer.block.attn.to_q[0].weight) + torch.nn.init.xavier_uniform_(layer.block.attn.to_k[0].weight) + torch.nn.init.xavier_uniform_(layer.block.attn.to_v[0].weight) + torch.nn.init.xavier_uniform_(layer.block.attn.to_out[0].weight) + elif layer.block_type in ["mlp", "ff"]: + # Initialize weights for MLP layers + torch.nn.init.xavier_uniform_(layer.block.layer1.weight) + torch.nn.init.xavier_uniform_(layer.block.layer2.weight) + else: + raise ValueError(f"Unknown block type {layer.block_type}") + + # Restore the original random state + torch.set_rng_state(rng_state) + + def build_decode_head(self): + self.final_layer = FinalLayer( + hidden_size=self.model_channels, + spatial_patch_size=self.patch_spatial, + temporal_patch_size=self.patch_temporal, + out_channels=self.out_channels, + use_adaln_lora=self.use_adaln_lora, + adaln_lora_dim=self.adaln_lora_dim, + ) + + def build_patch_embed(self): + ( + concat_padding_mask, + in_channels, + patch_spatial, + patch_temporal, + model_channels, + ) = ( + self.concat_padding_mask, + self.in_channels, + self.patch_spatial, + self.patch_temporal, + self.model_channels, + ) + in_channels = in_channels + 1 if concat_padding_mask else in_channels + self.x_embedder = PatchEmbed( + spatial_patch_size=patch_spatial, + temporal_patch_size=patch_temporal, + in_channels=in_channels, + out_channels=model_channels, + bias=False, + keep_spatio=True, + legacy_patch_emb=self.legacy_patch_emb, + ) + # Initialize patch_embed like nn.Linear (instead of nn.Conv2d) + if self.legacy_patch_emb: + w = self.x_embedder.proj.weight.data + nn.init.xavier_uniform_(w.view([w.shape[0], -1])) + + def build_additional_timestamp_embedder(self): + if self.additional_timestamp_channels: + self.additional_timestamp_embedder = nn.ModuleDict() + for cond_name, cond_emb_channels in self.additional_timestamp_channels.items(): + log.critical( + f"Building additional timestamp embedder for {cond_name} with {cond_emb_channels} channels" + ) + self.additional_timestamp_embedder[cond_name] = nn.Sequential( + SDXLTimesteps(cond_emb_channels), + SDXLTimestepEmbedding(cond_emb_channels, cond_emb_channels), + ) + + def prepare_additional_timestamp_embedder(self, **kwargs): + condition_concat = [] + + for cond_name, embedder in self.additional_timestamp_embedder.items(): + condition_concat.append(embedder(kwargs[cond_name])[0]) + embedding = torch.cat(condition_concat, dim=1) + if embedding.shape[1] < self.model_channels: + embedding = nn.functional.pad(embedding, (0, self.model_channels - embedding.shape[1])) + return embedding + + def build_pos_embed(self): + if self.pos_emb_cls == "sincos": + cls_type = SinCosPosEmb + elif self.pos_emb_cls == "learnable": + cls_type = LearnableEmb3D + elif self.pos_emb_cls == "sincos_fps_aware": + cls_type = SinCosPosEmb_FPS_Aware + elif self.pos_emb_cls == "learnable_fps_aware": + cls_type = LearnableEmb3D_FPS_Aware + elif self.pos_emb_cls == "rope": + cls_type = VideoRopePositionEmb + elif self.pos_emb_cls == "rope3d": + cls_type = VideoRopePosition3DEmb + else: + raise ValueError(f"Unknown pos_emb_cls {self.pos_emb_cls}") + + log.critical(f"Building positional embedding with {self.pos_emb_cls} class, impl {cls_type}") + kwargs = dict( + model_channels=self.model_channels, + len_h=self.max_img_h // self.patch_spatial, + len_w=self.max_img_w // self.patch_spatial, + len_t=self.max_frames // self.patch_temporal, + max_fps=self.max_fps, + min_fps=self.min_fps, + is_learnable=self.pos_emb_learnable, + interpolation=self.pos_emb_interpolation, + head_dim=self.model_channels // self.num_heads, + h_extrapolation_ratio=self.rope_h_extrapolation_ratio, + w_extrapolation_ratio=self.rope_w_extrapolation_ratio, + t_extrapolation_ratio=self.rope_t_extrapolation_ratio, + ) + self.pos_embedder = cls_type( + **kwargs, + ) + + if self.extra_per_block_abs_pos_emb: + assert self.extra_per_block_abs_pos_emb_type in [ + "sincos", + "learnable", + ], f"Unknown extra_per_block_abs_pos_emb_type {self.extra_per_block_abs_pos_emb_type}" + kwargs["h_extrapolation_ratio"] = self.extra_h_extrapolation_ratio + kwargs["w_extrapolation_ratio"] = self.extra_w_extrapolation_ratio + kwargs["t_extrapolation_ratio"] = self.extra_t_extrapolation_ratio + if self.extra_per_block_abs_pos_emb_type == "sincos": + self.extra_pos_embedder = SinCosPosEmbAxis( + **kwargs, + ) + elif self.extra_per_block_abs_pos_emb_type == "learnable": + self.extra_pos_embedder = LearnablePosEmbAxis( + **kwargs, + ) + + def prepare_embedded_sequence( + self, + x_B_C_T_H_W: torch.Tensor, + fps: Optional[torch.Tensor] = None, + padding_mask: Optional[torch.Tensor] = None, + latent_condition: Optional[torch.Tensor] = None, + latent_condition_sigma: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + """ + Prepares an embedded sequence tensor by applying positional embeddings and handling padding masks. + + Args: + x_B_C_T_H_W (torch.Tensor): video + fps (Optional[torch.Tensor]): Frames per second tensor to be used for positional embedding when required. + If None, a default value (`self.base_fps`) will be used. + padding_mask (Optional[torch.Tensor]): current it is not used + + Returns: + Tuple[torch.Tensor, Optional[torch.Tensor]]: + - A tensor of shape (B, T, H, W, D) with the embedded sequence. + - An optional positional embedding tensor, returned only if the positional embedding class + (`self.pos_emb_cls`) includes 'rope'. Otherwise, None. + + Notes: + - If `self.concat_padding_mask` is True, a padding mask channel is concatenated to the input tensor. + - The method of applying positional embeddings depends on the value of `self.pos_emb_cls`. + - If 'rope' is in `self.pos_emb_cls` (case insensitive), the positional embeddings are generated using + the `self.pos_embedder` with the shape [T, H, W]. + - If "fps_aware" is in `self.pos_emb_cls`, the positional embeddings are generated using the `self.pos_embedder` + with the fps tensor. + - Otherwise, the positional embeddings are generated without considering fps. + """ + if self.concat_padding_mask: + padding_mask = transforms.functional.resize( + padding_mask, list(x_B_C_T_H_W.shape[-2:]), interpolation=transforms.InterpolationMode.NEAREST + ) + x_B_C_T_H_W = torch.cat( + [x_B_C_T_H_W, padding_mask.unsqueeze(1).repeat(1, 1, x_B_C_T_H_W.shape[2], 1, 1)], dim=1 + ) + x_B_T_H_W_D = self.x_embedder(x_B_C_T_H_W) + + if self.extra_per_block_abs_pos_emb: + extra_pos_emb = self.extra_pos_embedder(x_B_T_H_W_D, fps=fps) + else: + extra_pos_emb = None + + if "rope" in self.pos_emb_cls.lower(): + return x_B_T_H_W_D, self.pos_embedder(x_B_T_H_W_D, fps=fps), extra_pos_emb + + if "fps_aware" in self.pos_emb_cls: + x_B_T_H_W_D = x_B_T_H_W_D + self.pos_embedder(x_B_T_H_W_D, fps=fps) # [B, T, H, W, D] + else: + x_B_T_H_W_D = x_B_T_H_W_D + self.pos_embedder(x_B_T_H_W_D) # [B, T, H, W, D] + return x_B_T_H_W_D, None, extra_pos_emb + + def decoder_head( + self, + x_B_T_H_W_D: torch.Tensor, + emb_B_D: torch.Tensor, + crossattn_emb: torch.Tensor, + origin_shape: Tuple[int, int, int, int, int], # [B, C, T, H, W] + crossattn_mask: Optional[torch.Tensor] = None, + adaln_lora_B_3D: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + del crossattn_emb, crossattn_mask + B, C, T_before_patchify, H_before_patchify, W_before_patchify = origin_shape + x_BT_HW_D = rearrange(x_B_T_H_W_D, "B T H W D -> (B T) (H W) D") + x_BT_HW_D = self.final_layer(x_BT_HW_D, emb_B_D, adaln_lora_B_3D=adaln_lora_B_3D) + # This is to ensure x_BT_HW_D has the correct shape because + # when we merge T, H, W into one dimension, x_BT_HW_D has shape (B * T * H * W, 1*1, D). + x_BT_HW_D = x_BT_HW_D.view( + B * T_before_patchify // self.patch_temporal, + H_before_patchify // self.patch_spatial * W_before_patchify // self.patch_spatial, + -1, + ) + x_B_D_T_H_W = rearrange( + x_BT_HW_D, + "(B T) (H W) (p1 p2 t C) -> B C (T t) (H p1) (W p2)", + p1=self.patch_spatial, + p2=self.patch_spatial, + H=H_before_patchify // self.patch_spatial, + W=W_before_patchify // self.patch_spatial, + t=self.patch_temporal, + B=B, + ) + return x_B_D_T_H_W + + def forward_before_blocks( + self, + x: torch.Tensor, + timesteps: torch.Tensor, + crossattn_emb: torch.Tensor, + crossattn_mask: Optional[torch.Tensor] = None, + fps: Optional[torch.Tensor] = None, + image_size: Optional[torch.Tensor] = None, + padding_mask: Optional[torch.Tensor] = None, + scalar_feature: Optional[torch.Tensor] = None, + data_type: Optional[DataType] = DataType.VIDEO, + latent_condition: Optional[torch.Tensor] = None, + latent_condition_sigma: Optional[torch.Tensor] = None, + **kwargs, + ) -> torch.Tensor: + """ + Args: + x: (B, C, T, H, W) tensor of spatial-temp inputs + timesteps: (B, ) tensor of timesteps + crossattn_emb: (B, N, D) tensor of cross-attention embeddings + crossattn_mask: (B, N) tensor of cross-attention masks + """ + del kwargs + assert isinstance( + data_type, DataType + ), f"Expected DataType, got {type(data_type)}. We need discuss this flag later." + original_shape = x.shape + x_B_T_H_W_D, rope_emb_L_1_1_D, extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D = self.prepare_embedded_sequence( + x, + fps=fps, + padding_mask=padding_mask, + latent_condition=latent_condition, + latent_condition_sigma=latent_condition_sigma, + ) + # logging affline scale information + affline_scale_log_info = {} + + timesteps_B_D, adaln_lora_B_3D = self.t_embedder(timesteps.flatten()) + affline_emb_B_D = timesteps_B_D + affline_scale_log_info["timesteps_B_D"] = timesteps_B_D.detach() + + if scalar_feature is not None: + raise NotImplementedError("Scalar feature is not implemented yet.") + timesteps_B_D = timesteps_B_D + scalar_feature.mean(dim=1) + + if self.additional_timestamp_channels: + additional_cond_B_D = self.prepare_additional_timestamp_embedder( + bs=x.shape[0], + fps=fps, + h=image_size[:, 0], + w=image_size[:, 1], + org_h=image_size[:, 2], + org_w=image_size[:, 3], + ) + + affline_emb_B_D += additional_cond_B_D + affline_scale_log_info["additional_cond_B_D"] = additional_cond_B_D.detach() + + affline_scale_log_info["affline_emb_B_D"] = affline_emb_B_D.detach() + affline_emb_B_D = self.affline_norm(affline_emb_B_D) + + # for logging purpose + self.affline_scale_log_info = affline_scale_log_info + self.affline_emb = affline_emb_B_D + self.crossattn_emb = crossattn_emb + self.crossattn_mask = crossattn_mask + + if self.use_cross_attn_mask: + crossattn_mask = crossattn_mask[:, None, None, :].to(dtype=torch.bool) # [B, 1, 1, length] + else: + crossattn_mask = None + + if self.blocks["block0"].x_format == "THWBD": + x = rearrange(x_B_T_H_W_D, "B T H W D -> T H W B D") + if extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D is not None: + extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D = rearrange( + extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D, "B T H W D -> T H W B D" + ) + crossattn_emb = rearrange(crossattn_emb, "B M D -> M B D") + + if crossattn_mask: + crossattn_mask = rearrange(crossattn_mask, "B M -> M B") + + if self.sequence_parallel: + tp_group = parallel_state.get_tensor_model_parallel_group() + # Sequence parallel requires the input tensor to be scattered along the first dimension. + assert self.block_config == "FA-CA-MLP" # Only support this block config for now + T, H, W, B, D = x.shape + # variable name x_T_H_W_B_D is no longer valid. x is reshaped to THW*1*1*b*D and will be reshaped back in FinalLayer + x = x.view(T * H * W, 1, 1, B, D) + assert x.shape[0] % parallel_state.get_tensor_model_parallel_world_size() == 0 + x = scatter_along_first_dim(x, tp_group) + + if extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D is not None: + extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D = extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D.view( + T * H * W, 1, 1, B, D + ) + extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D = scatter_along_first_dim( + extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D, tp_group + ) + + elif self.blocks["block0"].x_format == "BTHWD": + x = x_B_T_H_W_D + else: + raise ValueError(f"Unknown x_format {self.blocks[0].x_format}") + output = { + "x": x, + "affline_emb_B_D": affline_emb_B_D, + "crossattn_emb": crossattn_emb, + "crossattn_mask": crossattn_mask, + "rope_emb_L_1_1_D": rope_emb_L_1_1_D, + "adaln_lora_B_3D": adaln_lora_B_3D, + "original_shape": original_shape, + "extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D": extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D, + } + return output + + def forward_blocks_regular( + self, + x, + affline_emb_B_D, + crossattn_emb, + crossattn_mask, + rope_emb_L_1_1_D, + adaln_lora_B_3D, + extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D, + feature_indices, + original_shape, + x_ctrl, + return_features_early, + ): + features = [] + for name, block in self.blocks.items(): + assert ( + self.blocks["block0"].x_format == block.x_format + ), f"First block has x_format {self.blocks[0].x_format}, got {block.x_format}" + x = block( + x, + affline_emb_B_D, + crossattn_emb, + crossattn_mask, + rope_emb_L_1_1_D=rope_emb_L_1_1_D, + adaln_lora_B_3D=adaln_lora_B_3D, + extra_per_block_pos_emb=extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D, + ) + + # Extract features + block_idx = int(name.split("block")[-1]) + if block_idx in feature_indices: + B, C, T, H, W = original_shape + H = H // self.patch_spatial + W = W // self.patch_spatial + T = T // self.patch_temporal + if self.sequence_parallel: + x_feat = gather_along_first_dim(x, parallel_state.get_tensor_model_parallel_group()) + x_B_T_H_W_D = rearrange(x_feat, "(T H W) 1 1 B D -> B T H W D", T=T, H=H, W=W) + else: + x_feat = x + if self.blocks["block0"].x_format == "THWBD": + x_B_T_H_W_D = rearrange(x_feat, "T H W B D -> B T H W D", T=T, H=H, W=W) + elif self.blocks["block0"].x_format == "BTHWD": + x_B_T_H_W_D = x_feat + else: + raise ValueError(f"Unknown x_format {self.blocks[-1].x_format}") + + features.append(x_B_T_H_W_D) + + if x_ctrl is not None and name in x_ctrl: + x = x + x_ctrl[name] + # If we have all of the features, we can exit early + if return_features_early and len(features) == len(feature_indices): + return features + + if self.blocks["block0"].x_format == "THWBD": + x_B_T_H_W_D = rearrange(x, "T H W B D -> B T H W D") + elif self.blocks["block0"].x_format == "BTHWD": + x_B_T_H_W_D = x + else: + raise ValueError(f"Unknown x_format {self.blocks[-1].x_format}") + + x_B_D_T_H_W = self.decoder_head( + x_B_T_H_W_D=x_B_T_H_W_D, + emb_B_D=affline_emb_B_D, + crossattn_emb=None, + origin_shape=original_shape, + crossattn_mask=None, + adaln_lora_B_3D=adaln_lora_B_3D, + ) + + if len(feature_indices) == 0: + # no features requested, return only the model output + return x_B_D_T_H_W + else: + # score and features; score, features + return x_B_D_T_H_W, features + + def forward_blocks_memory_save( + self, + x, + affline_emb_B_D, + crossattn_emb, + crossattn_mask, + rope_emb_L_1_1_D, + adaln_lora_B_3D, + extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D, + feature_indices, + original_shape, + x_ctrl, + return_features_early, + ): + x_before_gate = 0 + x_skip = rearrange(x, "T H W B D -> (T H W) B D") + assert self.blocks["block0"].x_format == "THWBD" + if extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D is not None: + extra_per_block_pos_emb = rearrange(extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D, "T H W B D -> (T H W) B D") + else: + extra_per_block_pos_emb = None + gate_L_B_D = 1.0 + + features = [] + for name, block in self.blocks.items(): + gate_L_B_D, x_before_gate, x_skip = block( + x_before_gate, + x_skip, + gate_L_B_D, + affline_emb_B_D, + crossattn_emb, + crossattn_mask, + rope_emb_L_1_1_D=rope_emb_L_1_1_D, + adaln_lora_B_3D=adaln_lora_B_3D, + extra_per_block_pos_emb=extra_per_block_pos_emb, + ) + + # Extract features. + # Convert the block index in the memory save mode to the block index in the regular mode. + block_idx = int(name.split("block")[-1]) - 1 + if block_idx in feature_indices: + B, C, T_before_patchify, H_before_patchify, W_before_patchify = original_shape + H = H_before_patchify // self.patch_spatial + W = W_before_patchify // self.patch_spatial + T = T_before_patchify // self.patch_temporal + if self.sequence_parallel: + x_feat = gather_along_first_dim(x_skip, parallel_state.get_tensor_model_parallel_group()) + x_B_T_H_W_D = rearrange(x_feat, "(T H W) 1 1 B D -> B T H W D", T=T, H=H, W=W) + else: + x_feat = x_skip + x_B_T_H_W_D = rearrange(x_feat, "(T H W) B D -> B T H W D", T=T, H=H, W=W) + + features.append(x_B_T_H_W_D) + + new_name = f"block{block_idx}" + if x_ctrl is not None and new_name in x_ctrl: + x_ctrl_ = x_ctrl[new_name] + x_ctrl_ = rearrange(x_ctrl_, "T H W B D -> (T H W) B D") + x_skip = x_skip + x_ctrl_ + # If we have all of the features, we can exit early + if return_features_early and len(features) == len(feature_indices): + return features + + x_THW_B_D_before_gate = x_before_gate + x_THW_B_D_skip = x_skip + + B, C, T_before_patchify, H_before_patchify, W_before_patchify = original_shape + x_BT_HW_D_before_gate = rearrange( + x_THW_B_D_before_gate, + "(T H W) B D -> (B T) (H W) D", + T=T_before_patchify // self.patch_temporal, + H=H_before_patchify // self.patch_spatial, + W=W_before_patchify // self.patch_spatial, + ) + x_BT_HW_D_skip = rearrange( + x_THW_B_D_skip, + "(T H W) B D -> (B T) (H W) D", + T=T_before_patchify // self.patch_temporal, + H=H_before_patchify // self.patch_spatial, + W=W_before_patchify // self.patch_spatial, + ) + + x_BT_HW_D = self.final_layer.forward_with_memory_save( + x_BT_HW_D_before_gate=x_BT_HW_D_before_gate, + x_BT_HW_D_skip=x_BT_HW_D_skip, + gate_L_B_D=gate_L_B_D, + emb_B_D=affline_emb_B_D, + adaln_lora_B_3D=adaln_lora_B_3D, + ) + + # This is to ensure x_BT_HW_D has the correct shape because + # when we merge T, H, W into one dimension, x_BT_HW_D has shape (B * T * H * W, 1*1, D). + x_BT_HW_D = x_BT_HW_D.view( + B * T_before_patchify // self.patch_temporal, + H_before_patchify // self.patch_spatial * W_before_patchify // self.patch_spatial, + -1, + ) + x_B_D_T_H_W = rearrange( + x_BT_HW_D, + "(B T) (H W) (p1 p2 t C) -> B C (T t) (H p1) (W p2)", + p1=self.patch_spatial, + p2=self.patch_spatial, + H=H_before_patchify // self.patch_spatial, + W=W_before_patchify // self.patch_spatial, + t=self.patch_temporal, + B=B, + ) + if len(feature_indices) == 0: + # no features requested, return only the model output + return x_B_D_T_H_W + else: + # score and features; score, features + return x_B_D_T_H_W, features + + def forward( + self, + x: torch.Tensor, + timesteps: torch.Tensor, + crossattn_emb: torch.Tensor, + crossattn_mask: Optional[torch.Tensor] = None, + fps: Optional[torch.Tensor] = None, + image_size: Optional[torch.Tensor] = None, + padding_mask: Optional[torch.Tensor] = None, + scalar_feature: Optional[torch.Tensor] = None, + data_type: Optional[DataType] = DataType.VIDEO, + x_ctrl: Optional[dict] = None, + latent_condition: Optional[torch.Tensor] = None, + latent_condition_sigma: Optional[torch.Tensor] = None, + feature_indices: Optional[Container[int]] = None, + return_features_early: bool = False, + condition_video_augment_sigma: Optional[torch.Tensor] = None, + **kwargs, + ) -> torch.Tensor | List[torch.Tensor] | Tuple[torch.Tensor, List[torch.Tensor]]: + """ + Args: + x: (B, C, T, H, W) tensor of spatial-temp inputs + timesteps: (B, ) tensor of timesteps + crossattn_emb: (B, N, D) tensor of cross-attention embeddings + crossattn_mask: (B, N) tensor of cross-attention masks + feature_indices: A set of feature indices (a set of integers) decides which blocks + to extract features from. If the set is non-empty, then features will be returned. + By default, feature_indices=None means extract no features. + return_features_early: If true, the forward pass returns the features once the set is complete. + This means the forward pass will not finish completely and no final output is returned. + condition_video_augment_sigma: (B,) used in lvg(long video generation), we add noise with this sigma to augment condition input, the lvg model will condition on the condition_video_augment_sigma value; + we need forward_before_blocks pass to the forward_before_blocks function. + """ + if feature_indices is None: + feature_indices = {} + if return_features_early and len(feature_indices) == 0: + # Exit immediately if user requested this. + return [] + + inputs = self.forward_before_blocks( + x=x, + timesteps=timesteps, + crossattn_emb=crossattn_emb, + crossattn_mask=crossattn_mask, + fps=fps, + image_size=image_size, + padding_mask=padding_mask, + scalar_feature=scalar_feature, + data_type=data_type, + latent_condition=latent_condition, + latent_condition_sigma=latent_condition_sigma, + condition_video_augment_sigma=condition_video_augment_sigma, + **kwargs, + ) + x, affline_emb_B_D, crossattn_emb, crossattn_mask, rope_emb_L_1_1_D, adaln_lora_B_3D, original_shape = ( + inputs["x"], + inputs["affline_emb_B_D"], + inputs["crossattn_emb"], + inputs["crossattn_mask"], + inputs["rope_emb_L_1_1_D"], + inputs["adaln_lora_B_3D"], + inputs["original_shape"], + ) + extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D = inputs["extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D"] + if extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D is not None: + assert ( + x.shape == extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D.shape + ), f"{x.shape} != {extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D.shape} {original_shape}" + + if self.use_memory_save: + return self.forward_blocks_memory_save( + x, + affline_emb_B_D, + crossattn_emb, + crossattn_mask, + rope_emb_L_1_1_D, + adaln_lora_B_3D, + extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D, + feature_indices, + original_shape, + x_ctrl, + return_features_early, + ) + + return self.forward_blocks_regular( + x, + affline_emb_B_D, + crossattn_emb, + crossattn_mask, + rope_emb_L_1_1_D, + adaln_lora_B_3D, + extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D, + feature_indices, + original_shape, + x_ctrl, + return_features_early, + ) + + @property + def fsdp_wrap_block_cls(self): + return DITBuildingBlock + + def enable_context_parallel(self, cp_group: ProcessGroup): + cp_ranks = get_process_group_ranks(cp_group) + cp_size = len(cp_ranks) + # Set these attributes for spliting the data after embedding. + self.cp_group = cp_group + # Set these attributes for computing the loss. + self.cp_size = cp_size + + self.pos_embedder.enable_context_parallel(cp_group) + if self.extra_per_block_abs_pos_emb: + self.extra_pos_embedder.enable_context_parallel(cp_group) + # Loop through the model to set up context parallel. + for block in self.blocks.values(): + for layer in block.blocks: + if layer.block_type in ["mlp", "ff"]: + continue + elif layer.block_type in ["cross_attn", "ca"]: + continue + else: + layer.block.attn.attn_op.set_context_parallel_group(cp_group, cp_ranks, torch.cuda.Stream()) + + log.debug(f"[CP] Enable context parallelism with size {cp_size}") + + def disable_context_parallel(self): + self.cp_group = None + self.cp_size = None + + self.pos_embedder.disable_context_parallel() + if self.extra_per_block_abs_pos_emb: + self.extra_pos_embedder.disable_context_parallel() + + # Loop through the model to disable context parallel. + for block in self.blocks.values(): + for layer in block.blocks: + if layer.block_type in ["mlp", "ff"]: + continue + elif layer.block_type in ["cross_attn", "ca"]: + continue + else: + layer.block.attn.attn_op.cp_group = None + layer.block.attn.attn_op.cp_ranks = None + layer.block.attn.attn_op.cp_stream = None + + log.debug("[CP] Disable context parallelism.") + + def enable_sequence_parallel(self): + self._set_sequence_parallel(True) + + def disable_sequence_parallel(self): + self._set_sequence_parallel(False) + + def _set_sequence_parallel(self, status: bool): + self.sequence_parallel = status + self.final_layer.sequence_parallel = status + for block in self.blocks.values(): + for layer in block.blocks: + if layer.block_type in ["full_attn", "fa", "cross_attn", "ca"]: + layer.block.attn.to_q[0].sequence_parallel = status + layer.block.attn.to_k[0].sequence_parallel = status + layer.block.attn.to_v[0].sequence_parallel = status + layer.block.attn.to_out[0].sequence_parallel = status + layer.block.attn.attn_op.sequence_parallel = status + elif layer.block_type in ["mlp", "ff"]: + layer.block.layer1.sequence_parallel = status + layer.block.layer2.sequence_parallel = status + else: + raise ValueError(f"Unknown block type {layer.block_type}") + + @property + def is_context_parallel_enabled(self): + return self.cp_group is not None diff --git a/cosmos_transfer1/diffusion/training/networks/general_dit_ctrl_enc.py b/cosmos_transfer1/diffusion/training/networks/general_dit_ctrl_enc.py new file mode 100644 index 00000000..a3474fb5 --- /dev/null +++ b/cosmos_transfer1/diffusion/training/networks/general_dit_ctrl_enc.py @@ -0,0 +1,402 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +ControlNet Encoder based on GeneralDIT +""" + +from typing import List, Optional, Tuple + +import numpy as np +import torch +from einops import rearrange + +from megatron.core import parallel_state +from torch import nn +from torchvision import transforms + +from cosmos_transfer1.diffusion.conditioner import DataType +from cosmos_transfer1.diffusion.module.parallel import split_inputs_cp +from cosmos_transfer1.diffusion.module.blocks import zero_module +from cosmos_transfer1.diffusion.training.modules.blocks import PatchEmbed +from cosmos_transfer1.diffusion.training.networks.general_dit_video_conditioned import VideoExtendGeneralDIT as GeneralDIT +from cosmos_transfer1.diffusion.training.tensor_parallel import scatter_along_first_dim + + +class GeneralDITEncoder(GeneralDIT): + """ + ControlNet Encoder based on GeneralDIT. Heavily borrowed from GeneralDIT with minor modifications. + """ + + def __init__(self, *args, **kwargs): + hint_channels = kwargs.pop("hint_channels", 16) + self.dropout_ctrl_branch = kwargs.pop("dropout_ctrl_branch", 0.5) + num_control_blocks = kwargs.pop("num_control_blocks", None) + if num_control_blocks is not None: + assert num_control_blocks > 0 and num_control_blocks <= kwargs["num_blocks"] + kwargs["layer_mask"] = [False] * num_control_blocks + [True] * (kwargs["num_blocks"] - num_control_blocks) + self.random_drop_control_blocks = kwargs.pop("random_drop_control_blocks", False) + super().__init__(*args, **kwargs) + num_blocks = self.num_blocks + model_channels = self.model_channels + layer_mask = kwargs.get("layer_mask", None) + layer_mask = [False] * num_blocks if layer_mask is None else layer_mask + self.layer_mask = layer_mask + self.hint_channels = hint_channels + self.build_hint_patch_embed() + hint_nf = [16, 16, 32, 32, 96, 96, 256] + nonlinearity = nn.SiLU() + input_hint_block = [nn.Linear(model_channels, hint_nf[0]), nonlinearity] + for i in range(len(hint_nf) - 1): + input_hint_block += [nn.Linear(hint_nf[i], hint_nf[i + 1]), nonlinearity] + self.input_hint_block = nn.Sequential(*input_hint_block) + # Initialize weights + self.init_weights() + self.zero_blocks = nn.ModuleDict() + for idx in range(num_blocks): + if layer_mask[idx]: + continue + self.zero_blocks[f"block{idx}"] = zero_module(nn.Linear(model_channels, model_channels)) + self.input_hint_block.append(zero_module(nn.Linear(hint_nf[-1], model_channels))) + + def _set_sequence_parallel(self, status: bool): + self.zero_blocks.sequence_parallel = status + self.input_hint_block.sequence_parallel = status + super()._set_sequence_parallel(status) + + def build_hint_patch_embed(self): + concat_padding_mask, in_channels, patch_spatial, patch_temporal, model_channels = ( + self.concat_padding_mask, + self.hint_channels, + self.patch_spatial, + self.patch_temporal, + self.model_channels, + ) + in_channels = in_channels + 1 if concat_padding_mask else in_channels + self.x_embedder2 = PatchEmbed( + spatial_patch_size=patch_spatial, + temporal_patch_size=patch_temporal, + in_channels=in_channels, + out_channels=model_channels, + bias=False, + keep_spatio=True, + legacy_patch_emb=self.legacy_patch_emb, + ) + + if self.legacy_patch_emb: + # Initialize patch_embed like nn.Linear (instead of nn.Conv2d) + w = self.x_embedder2.proj.weight.data + nn.init.xavier_uniform_(w.view([w.shape[0], -1])) + + def prepare_hint_embedded_sequence( + self, x_B_C_T_H_W: torch.Tensor, fps: Optional[torch.Tensor] = None, padding_mask: Optional[torch.Tensor] = None + ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + if self.concat_padding_mask: + padding_mask = transforms.functional.resize( + padding_mask, list(x_B_C_T_H_W.shape[-2:]), interpolation=transforms.InterpolationMode.NEAREST + ) + x_B_C_T_H_W = torch.cat( + [x_B_C_T_H_W, padding_mask.unsqueeze(1).repeat(x_B_C_T_H_W.shape[0], 1, x_B_C_T_H_W.shape[2], 1, 1)], + dim=1, + ) + + x_B_T_H_W_D = self.x_embedder2(x_B_C_T_H_W) + + if "rope" in self.pos_emb_cls.lower(): + return x_B_T_H_W_D, self.pos_embedder(x_B_T_H_W_D, fps=fps) + + if "fps_aware" in self.pos_emb_cls: + x_B_T_H_W_D = x_B_T_H_W_D + self.pos_embedder(x_B_T_H_W_D, fps=fps) # [B, T, H, W, D] + else: + x_B_T_H_W_D = x_B_T_H_W_D + self.pos_embedder(x_B_T_H_W_D) # [B, T, H, W, D] + return x_B_T_H_W_D, None + + def encode_hint( + self, + hint: torch.Tensor, + fps: Optional[torch.Tensor] = None, + padding_mask: Optional[torch.Tensor] = None, + data_type: Optional[DataType] = DataType.VIDEO, + ) -> torch.Tensor: + assert hint.size(1) <= self.hint_channels, f"Expected hint channels <= {self.hint_channels}, got {hint.size(1)}" + if hint.size(1) < self.hint_channels: + padding_shape = list(hint.shape) + padding_shape[1] = self.hint_channels - hint.size(1) + hint = torch.cat([hint, torch.zeros(*padding_shape, dtype=hint.dtype, device=hint.device)], dim=1) + assert isinstance( + data_type, DataType + ), f"Expected DataType, got {type(data_type)}. We need discuss this flag later." + + hint_B_T_H_W_D, _ = self.prepare_hint_embedded_sequence(hint, fps=fps, padding_mask=padding_mask) + + hint = rearrange(hint_B_T_H_W_D, "B T H W D -> T H W B D") + if self.sequence_parallel: + tp_group = parallel_state.get_tensor_model_parallel_group() + T, H, W, B, D = hint.shape + hint = hint.view(T * H * W, 1, 1, B, -1) + hint = scatter_along_first_dim(hint, tp_group) + + guided_hint = self.input_hint_block(hint) + return guided_hint + + def forward( + self, + x: torch.Tensor, + timesteps: torch.Tensor, + crossattn_emb: torch.Tensor, + crossattn_mask: Optional[torch.Tensor] = None, + fps: Optional[torch.Tensor] = None, + image_size: Optional[torch.Tensor] = None, + padding_mask: Optional[torch.Tensor] = None, + scalar_feature: Optional[torch.Tensor] = None, + data_type: Optional[DataType] = DataType.VIDEO, + hint_key: Optional[str] = None, + base_model: Optional[nn.Module] = None, + control_weight: Optional[float] = 1.0, + num_layers_to_use: Optional[int] = -1, + condition_video_input_mask: Optional[torch.Tensor] = None, + **kwargs, + ) -> torch.Tensor | List[torch.Tensor] | Tuple[torch.Tensor, List[torch.Tensor]]: + """ + Args: + x: (B, C, T, H, W) tensor of spatial-temp inputs + timesteps: (B, ) tensor of timesteps + crossattn_emb: (B, N, D) tensor of cross-attention embeddings + crossattn_mask: (B, N) tensor of cross-attention masks + """ + # record the input as they are replaced in this forward + x_input = x + crossattn_emb_input = crossattn_emb + crossattn_mask_input = crossattn_mask + condition_video_input_mask_input = condition_video_input_mask + + hint = kwargs.pop(hint_key) + if hint is None: + print("using none hint") + return base_model.net.forward( + x=x_input, + timesteps=timesteps, + crossattn_emb=crossattn_emb_input, + crossattn_mask=crossattn_mask_input, + fps=fps, + image_size=image_size, + padding_mask=padding_mask, + scalar_feature=scalar_feature, + data_type=data_type, + condition_video_input_mask=condition_video_input_mask_input, + **kwargs, + ) + if hasattr(self, "hint_encoders"): # for multicontrol + guided_hints = [] + for i in range(hint.shape[1]): + self.input_hint_block = self.hint_encoders[i].input_hint_block + self.pos_embedder = self.hint_encoders[i].pos_embedder + self.x_embedder2 = self.hint_encoders[i].x_embedder2 + guided_hints += [self.encode_hint(hint[:, i], fps=fps, padding_mask=padding_mask, data_type=data_type)] + else: + guided_hints = self.encode_hint(hint, fps=fps, padding_mask=padding_mask, data_type=data_type) + guided_hints = torch.chunk(guided_hints, hint.shape[0] // x.shape[0], dim=3) + # Only support multi-control at inference time + assert len(guided_hints) == 1 or not torch.is_grad_enabled() + + assert isinstance( + data_type, DataType + ), f"Expected DataType, got {type(data_type)}. We need discuss this flag later." + + B, C, T, H, W = x.shape + if data_type == DataType.VIDEO: + if condition_video_input_mask is not None: + if self.cp_group is not None: + condition_video_input_mask = split_inputs_cp( + condition_video_input_mask, seq_dim=2, cp_group=self.cp_group + ) + input_list = [x, condition_video_input_mask] + x = torch.cat(input_list, dim=1) + + elif data_type == DataType.IMAGE: + # For image, we dont have condition_video_input_mask, or condition_video_pose + # We need to add the extra channel for video condition mask + padding_channels = self.in_channels - x.shape[1] + x = torch.cat([x, torch.zeros((B, padding_channels, T, H, W), dtype=x.dtype, device=x.device)], dim=1) + else: + assert x.shape[1] == self.in_channels, f"Expected {self.in_channels} channels, got {x.shape[1]}" + + self.crossattn_emb = crossattn_emb + self.crossattn_mask = crossattn_mask + + if self.use_cross_attn_mask: + crossattn_mask = crossattn_mask[:, None, None, :].to(dtype=torch.bool) # [B, 1, 1, length] + else: + crossattn_mask = None + + crossattn_emb = rearrange(crossattn_emb, "B M D -> M B D") + if crossattn_mask: + crossattn_mask = rearrange(crossattn_mask, "B M -> M B") + + outs = {} + + # (Experimental, not used in the released model) if also training base model, sometimes drop the + # controlnet branch to only train base branch. This is to prevent the network become dependent on + # controlnet branch and make control weight useless. + is_training = torch.is_grad_enabled() + is_training_base_model = any(p.requires_grad for p in base_model.parameters()) + if is_training and is_training_base_model: + coin_flip = torch.rand(B).to(x.device) > self.dropout_ctrl_branch # prob for only training base model + if self.blocks["block0"].x_format == "THWBD": + coin_flip = coin_flip[None, None, None, :, None] + elif self.blocks["block0"].x_format == "BTHWD": + coin_flip = coin_flip[:, None, None, None, None] + else: + coin_flip = 1 + + num_control_blocks = self.layer_mask.index(True) + num_layers_to_use = num_control_blocks + control_gate_per_layer = [i < num_layers_to_use for i in range(num_control_blocks)] + + if isinstance(control_weight, torch.Tensor): + if control_weight.ndim == 0: # Single scalar tensor + control_weight = [float(control_weight)] * len(guided_hints) + elif control_weight.ndim == 1: # List of scalar weights + control_weight = [float(w) for w in control_weight] + else: # Spatial-temporal weight maps + control_weight = [w for w in control_weight] # Keep as tensor + else: + control_weight = [control_weight] * len(guided_hints) + + x_before_blocks = x.clone() + for i, guided_hint in enumerate(guided_hints): + x = x_before_blocks + if hasattr(self, "hint_encoders"): # for multicontrol + blocks = self.hint_encoders[i].blocks + zero_blocks = self.hint_encoders[i].zero_blocks + t_embedder = self.hint_encoders[i].t_embedder + affline_norm = self.hint_encoders[i].affline_norm + self.x_embedder = self.hint_encoders[i].x_embedder + self.extra_pos_embedder = self.hint_encoders[i].extra_pos_embedder + else: + blocks = self.blocks + zero_blocks = self.zero_blocks + t_embedder = self.t_embedder + affline_norm = self.affline_norm + + x_B_T_H_W_D, rope_emb_L_1_1_D, extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D = self.prepare_embedded_sequence( + x, fps=fps, padding_mask=padding_mask + ) + # logging affline scale information + affline_scale_log_info = {} + + timesteps_B_D, adaln_lora_B_3D = t_embedder(timesteps.flatten()) + affline_emb_B_D = timesteps_B_D + affline_scale_log_info["timesteps_B_D"] = timesteps_B_D.detach() + + if scalar_feature is not None: + raise NotImplementedError("Scalar feature is not implemented yet.") + + if self.additional_timestamp_channels: + additional_cond_B_D = self.prepare_additional_timestamp_embedder( + bs=x.shape[0], + fps=fps, + h=image_size[:, 0], + w=image_size[:, 1], + org_h=image_size[:, 2], + org_w=image_size[:, 3], + ) + + affline_emb_B_D += additional_cond_B_D + affline_scale_log_info["additional_cond_B_D"] = additional_cond_B_D.detach() + + affline_scale_log_info["affline_emb_B_D"] = affline_emb_B_D.detach() + affline_emb_B_D = affline_norm(affline_emb_B_D) + + # for logging purpose + self.affline_scale_log_info = affline_scale_log_info + self.affline_emb = affline_emb_B_D + + x = rearrange(x_B_T_H_W_D, "B T H W D -> T H W B D") + if extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D is not None: + extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D = rearrange( + extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D, "B T H W D -> T H W B D" + ) + + if self.sequence_parallel: + tp_group = parallel_state.get_tensor_model_parallel_group() + # Sequence parallel requires the input tensor to be scattered along the first dimension. + assert self.block_config == "FA-CA-MLP" # Only support this block config for now + T, H, W, B, D = x.shape + # variable name x_T_H_W_B_D is no longer valid. x is reshaped to THW*1*1*b*D and will be reshaped back in FinalLayer + x = x.view(T * H * W, 1, 1, B, D) + assert x.shape[0] % parallel_state.get_tensor_model_parallel_world_size() == 0 + x = scatter_along_first_dim(x, tp_group) + if extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D is not None: + extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D = extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D.view( + T * H * W, 1, 1, B, D + ) + extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D = scatter_along_first_dim( + extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D, tp_group + ) + + for idx, (name, block) in enumerate(blocks.items()): + assert ( + blocks["block0"].x_format == block.x_format + ), f"First block has x_format {blocks[0].x_format}, got {block.x_format}" + x = block( + x, + affline_emb_B_D, + crossattn_emb, + crossattn_mask, + rope_emb_L_1_1_D=rope_emb_L_1_1_D, + adaln_lora_B_3D=adaln_lora_B_3D, + extra_per_block_pos_emb=extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D, + ) + if guided_hint is not None: + x = x + guided_hint + guided_hint = None + + gate = control_gate_per_layer[idx] + if isinstance(control_weight[i], (float, int)) or control_weight[i].ndim < 2: + hint_val = zero_blocks[name](x) * control_weight[i] * coin_flip * gate + else: # Spatial-temporal weights [num_controls, B, 1, T, H, W] + control_feat = zero_blocks[name](x) + # Get current feature dimensions + weight_map = control_weight[i] # [B, 1, T, H, W] + # Reshape to match THWBD format + weight_map = weight_map.permute(2, 3, 4, 0, 1) # [T, H, W, B, 1] + weight_map = weight_map.view(T * H * W, 1, 1, B, 1) + + if self.sequence_parallel: + weight_map = scatter_along_first_dim(weight_map, tp_group) + + hint_val = control_feat * weight_map * coin_flip * gate + + if name not in outs: + outs[name] = hint_val + else: + outs[name] += hint_val + + output = base_model.net.forward( + x=x_input, + timesteps=timesteps, + crossattn_emb=crossattn_emb_input, + crossattn_mask=crossattn_mask_input, + fps=fps, + image_size=image_size, + padding_mask=padding_mask, + scalar_feature=scalar_feature, + data_type=data_type, + x_ctrl=outs, + condition_video_input_mask=condition_video_input_mask_input, + **kwargs, + ) + return output diff --git a/cosmos_transfer1/diffusion/training/networks/general_dit_video_conditioned.py b/cosmos_transfer1/diffusion/training/networks/general_dit_video_conditioned.py new file mode 100644 index 00000000..ac91d4b0 --- /dev/null +++ b/cosmos_transfer1/diffusion/training/networks/general_dit_video_conditioned.py @@ -0,0 +1,259 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Optional + +import torch +from einops import rearrange +from megatron.core import parallel_state +from torch import nn + +from cosmos_transfer1.diffusion.conditioner import DataType +from cosmos_transfer1.diffusion.module.blocks import SDXLTimesteps, SDXLTimestepEmbedding +from cosmos_transfer1.diffusion.module.parallel import split_inputs_cp +from cosmos_transfer1.diffusion.training.networks.general_dit import GeneralDIT +from cosmos_transfer1.diffusion.training.tensor_parallel import scatter_along_first_dim +from cosmos_transfer1.utils import log + + +class VideoExtendGeneralDIT(GeneralDIT): + def __init__(self, *args, in_channels=16 + 1, add_augment_sigma_embedding=False, **kwargs): + self.add_augment_sigma_embedding = add_augment_sigma_embedding + + # extra channel for video condition mask + super().__init__(*args, in_channels=in_channels, **kwargs) + log.info(f"VideoExtendGeneralDIT in_channels: {in_channels}") + + def build_additional_timestamp_embedder(self): + super().build_additional_timestamp_embedder() + if self.add_augment_sigma_embedding: + log.info("Adding augment sigma embedding") + self.augment_sigma_embedder = nn.Sequential( + SDXLTimesteps(self.model_channels), + SDXLTimestepEmbedding(self.model_channels, self.model_channels, use_adaln_lora=self.use_adaln_lora), + ) + + def init_weights(self): + if self.add_augment_sigma_embedding: + # Initialize timestep embedding for augment sigma + nn.init.normal_(self.augment_sigma_embedder[1].linear_1.weight, std=0.02) + if self.augment_sigma_embedder[1].linear_1.bias is not None: + nn.init.constant_(self.augment_sigma_embedder[1].linear_1.bias, 0) + nn.init.normal_(self.augment_sigma_embedder[1].linear_2.weight, std=0.02) + if self.augment_sigma_embedder[1].linear_2.bias is not None: + nn.init.constant_(self.augment_sigma_embedder[1].linear_2.bias, 0) + + super().init_weights() # Call this last since it wil call TP weight init + + def forward( + self, + x: torch.Tensor, + timesteps: torch.Tensor, + crossattn_emb: torch.Tensor, + crossattn_mask: Optional[torch.Tensor] = None, + fps: Optional[torch.Tensor] = None, + image_size: Optional[torch.Tensor] = None, + padding_mask: Optional[torch.Tensor] = None, + scalar_feature: Optional[torch.Tensor] = None, + data_type: Optional[DataType] = DataType.VIDEO, + video_cond_bool: Optional[torch.Tensor] = None, + condition_video_indicator: Optional[torch.Tensor] = None, + condition_video_input_mask: Optional[torch.Tensor] = None, + condition_video_augment_sigma: Optional[torch.Tensor] = None, + condition_video_pose: Optional[torch.Tensor] = None, + **kwargs, + ) -> torch.Tensor: + """Args: + condition_video_augment_sigma: (B) tensor of sigma value for the conditional input augmentation + condition_video_pose: (B, 1, T, H, W) tensor of pose condition + """ + B, C, T, H, W = x.shape + + if data_type == DataType.VIDEO: + assert ( + condition_video_input_mask is not None + ), "condition_video_input_mask is required for video data type; check if your model_obj is extend_model.FSDPDiffusionModel or the base DiffusionModel" + if self.cp_group is not None: + condition_video_input_mask = split_inputs_cp( + condition_video_input_mask, seq_dim=2, cp_group=self.cp_group + ) + condition_video_indicator = split_inputs_cp( + condition_video_indicator, seq_dim=2, cp_group=self.cp_group + ) + if condition_video_pose is not None: + condition_video_pose = split_inputs_cp(condition_video_pose, seq_dim=2, cp_group=self.cp_group) + + input_list = [x, condition_video_input_mask] + if condition_video_pose is not None: + if condition_video_pose.shape[2] > T: + log.warning( + f"condition_video_pose has more frames than the input video: {condition_video_pose.shape} > {x.shape}" + ) + condition_video_pose = condition_video_pose[:, :, :T, :, :].contiguous() + input_list.append(condition_video_pose) + x = torch.cat( + input_list, + dim=1, + ) + + if data_type == DataType.IMAGE: + # For image, we dont have condition_video_input_mask, or condition_video_pose + # We need to add the extra channel for video condition mask + padding_channels = self.in_channels - x.shape[1] + x = torch.cat([x, torch.zeros((B, padding_channels, T, H, W), dtype=x.dtype, device=x.device)], dim=1) + else: + assert x.shape[1] == self.in_channels, f"Expected {self.in_channels} channels, got {x.shape[1]}" + return super().forward( + x=x, + timesteps=timesteps, + crossattn_emb=crossattn_emb, + crossattn_mask=crossattn_mask, + fps=fps, + image_size=image_size, + padding_mask=padding_mask, + scalar_feature=scalar_feature, + data_type=data_type, + condition_video_augment_sigma=condition_video_augment_sigma, + **kwargs, + ) + + def forward_before_blocks( + self, + x: torch.Tensor, + timesteps: torch.Tensor, + crossattn_emb: torch.Tensor, + crossattn_mask: Optional[torch.Tensor] = None, + fps: Optional[torch.Tensor] = None, + image_size: Optional[torch.Tensor] = None, + padding_mask: Optional[torch.Tensor] = None, + scalar_feature: Optional[torch.Tensor] = None, + data_type: Optional[DataType] = DataType.VIDEO, + latent_condition: Optional[torch.Tensor] = None, + latent_condition_sigma: Optional[torch.Tensor] = None, + condition_video_augment_sigma: Optional[torch.Tensor] = None, + **kwargs, + ) -> torch.Tensor: + """ + Args: + x: (B, C, T, H, W) tensor of spatial-temp inputs + timesteps: (B, ) tensor of timesteps + crossattn_emb: (B, N, D) tensor of cross-attention embeddings + crossattn_mask: (B, N) tensor of cross-attention masks + + condition_video_augment_sigma: (B, T) tensor of sigma value for the conditional input augmentation + """ + del kwargs + assert isinstance( + data_type, DataType + ), f"Expected DataType, got {type(data_type)}. We need discuss this flag later." + original_shape = x.shape + x_B_T_H_W_D, rope_emb_L_1_1_D, extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D = self.prepare_embedded_sequence( + x, + fps=fps, + padding_mask=padding_mask, + latent_condition=latent_condition, + latent_condition_sigma=latent_condition_sigma, + ) + # logging affline scale information + affline_scale_log_info = {} + + timesteps_B_D, adaln_lora_B_3D = self.t_embedder(timesteps.flatten()) + affline_emb_B_D = timesteps_B_D + affline_scale_log_info["timesteps_B_D"] = timesteps_B_D.detach() + + if scalar_feature is not None: + raise NotImplementedError("Scalar feature is not implemented yet.") + timesteps_B_D = timesteps_B_D + scalar_feature.mean(dim=1) + if self.additional_timestamp_channels: + additional_cond_B_D = self.prepare_additional_timestamp_embedder( + bs=x.shape[0], + fps=fps, + h=image_size[:, 0], + w=image_size[:, 1], + org_h=image_size[:, 2], + org_w=image_size[:, 3], + ) + + affline_emb_B_D += additional_cond_B_D + affline_scale_log_info["additional_cond_B_D"] = additional_cond_B_D.detach() + if self.add_augment_sigma_embedding: + if condition_video_augment_sigma is None: + # Handling image case + # Note: for video case, when there is not condition frames, we also set it as zero, see + # the augment_conditional_latent_frames function in DiffusionV2WModel and ExtendDiffusionModel. + assert data_type == DataType.IMAGE, "condition_video_augment_sigma is required for video data type" + condition_video_augment_sigma = torch.zeros_like(timesteps.flatten()) + + affline_augment_sigma_emb_B_D, adaln_lora_sigma_emb_B_3D = self.augment_sigma_embedder( + condition_video_augment_sigma.flatten() + ) + affline_emb_B_D = affline_emb_B_D + affline_augment_sigma_emb_B_D + affline_scale_log_info["affline_emb_B_D"] = affline_emb_B_D.detach() + affline_emb_B_D = self.affline_norm(affline_emb_B_D) + + # for logging purpose + self.affline_scale_log_info = affline_scale_log_info + self.affline_emb = affline_emb_B_D + self.crossattn_emb = crossattn_emb + self.crossattn_mask = crossattn_mask + + if self.use_cross_attn_mask: + crossattn_mask = crossattn_mask[:, None, None, :].to(dtype=torch.bool) # [B, 1, 1, length] + else: + crossattn_mask = None + + if self.blocks["block0"].x_format == "THWBD": + x = rearrange(x_B_T_H_W_D, "B T H W D -> T H W B D") + if extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D is not None: + extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D = rearrange( + extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D, "B T H W D -> T H W B D" + ) + crossattn_emb = rearrange(crossattn_emb, "B M D -> M B D") + if crossattn_mask: + crossattn_mask = rearrange(crossattn_mask, "B M -> M B") + + if self.sequence_parallel: + tp_group = parallel_state.get_tensor_model_parallel_group() + # Sequence parallel requires the input tensor to be scattered along the first dimension. + assert self.block_config == "FA-CA-MLP" # Only support this block config for now + T, H, W, B, D = x.shape + # variable name x_T_H_W_B_D is no longer valid. x is reshaped to THW*1*1*b*D and will be reshaped back in FinalLayer + x = x.view(T * H * W, 1, 1, B, D) + assert x.shape[0] % parallel_state.get_tensor_model_parallel_world_size() == 0 + x = scatter_along_first_dim(x, tp_group) + + if extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D is not None: + extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D = extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D.view( + T * H * W, 1, 1, B, D + ) + extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D = scatter_along_first_dim( + extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D, tp_group + ) + + elif self.blocks["block0"].x_format == "BTHWD": + x = x_B_T_H_W_D + else: + raise ValueError(f"Unknown x_format {self.blocks[0].x_format}") + output = { + "x": x, + "affline_emb_B_D": affline_emb_B_D, + "crossattn_emb": crossattn_emb, + "crossattn_mask": crossattn_mask, + "rope_emb_L_1_1_D": rope_emb_L_1_1_D, + "adaln_lora_B_3D": adaln_lora_B_3D, + "original_shape": original_shape, + "extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D": extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D, + } + return output diff --git a/cosmos_transfer1/utils/config.py b/cosmos_transfer1/utils/config.py index 8f257354..8c15f398 100644 --- a/cosmos_transfer1/utils/config.py +++ b/cosmos_transfer1/utils/config.py @@ -33,82 +33,9 @@ from cosmos_transfer1.utils.lazy_config import LazyDict from cosmos_transfer1.utils.misc import Color from cosmos_transfer1.utils.callback import EMAModelCallback, ProgressBarCallback +from cosmos_transfer1.utils.ddp_config import DDPConfig, make_freezable -T = TypeVar("T") -def _is_attrs_instance(obj: object) -> bool: - """ - Helper function to check if an object is an instance of an attrs-defined class. - - Args: - obj: The object to check. - - Returns: - bool: True if the object is an instance of an attrs-defined class, False otherwise. - """ - return hasattr(obj, "__attrs_attrs__") - - -def make_freezable(cls: T) -> T: - """ - A decorator that adds the capability to freeze instances of an attrs-defined class. - - NOTE: This requires the wrapped attrs to be defined with attrs.define(slots=False) because we need - to hack on a "_is_frozen" attribute. - - This decorator enhances an attrs-defined class with the ability to be "frozen" at runtime. - Once an instance is frozen, its attributes cannot be changed. It also recursively freezes - any attrs-defined objects that are attributes of the class. - - Usage: - @make_freezable - @attrs.define(slots=False) - class MyClass: - attribute1: int - attribute2: str - - obj = MyClass(1, 'a') - obj.freeze() # Freeze the instance - obj.attribute1 = 2 # Raises AttributeError - - Args: - cls: The class to be decorated. - - Returns: - The decorated class with added freezing capability. - """ - - if not hasattr(cls, "__dict__"): - raise TypeError( - "make_freezable cannot be used with classes that do not define __dict__. Make sure that the wrapped " - "class was defined with `@attrs.define(slots=False)`" - ) - - original_setattr = cls.__setattr__ - - def setattr_override(self, key, value) -> None: # noqa: ANN001 - """ - Override __setattr__ to allow modifications during initialization - and prevent modifications once the instance is frozen. - """ - if hasattr(self, "_is_frozen") and self._is_frozen and key != "_is_frozen": - raise AttributeError("Cannot modify frozen instance") - original_setattr(self, key, value) # type: ignore - - cls.__setattr__ = setattr_override # type: ignore - - def freeze(self: object) -> None: - """ - Freeze the instance and all its attrs-defined attributes. - """ - for _, value in attrs.asdict(self, recurse=False).items(): - if _is_attrs_instance(value) and hasattr(value, "freeze"): - value.freeze() - self._is_frozen = True # type: ignore - - cls.freeze = freeze # type: ignore - - return cls def _pretty_print_attrs_instance(obj: object, indent: int = 0, use_color: bool = False) -> str: @@ -168,17 +95,6 @@ class EMAConfig: torch_compile_buffer_renaming: bool = False -@make_freezable -@attrs.define(slots=False) -class DDPConfig: - # Traverse the computation graph to find parameters that don't receive gradients. - find_unused_parameters: bool = False - # Set to True if the computation graph does not change during the whole training loop. - static_graph: bool = True - # Set to True if we want to synchronize buffers. Set to False if the sync is going to be handled elsewhere. - broadcast_buffers: bool = True - - @make_freezable @attrs.define(slots=False) class CuDNNConfig: diff --git a/cosmos_transfer1/utils/ddp_config.py b/cosmos_transfer1/utils/ddp_config.py new file mode 100644 index 00000000..d9713c7b --- /dev/null +++ b/cosmos_transfer1/utils/ddp_config.py @@ -0,0 +1,106 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import attrs + +from typing import TypeVar + +T = TypeVar("T") + +def _is_attrs_instance(obj: object) -> bool: + """ + Helper function to check if an object is an instance of an attrs-defined class. + + Args: + obj: The object to check. + + Returns: + bool: True if the object is an instance of an attrs-defined class, False otherwise. + """ + return hasattr(obj, "__attrs_attrs__") + + +def make_freezable(cls: T) -> T: + """ + A decorator that adds the capability to freeze instances of an attrs-defined class. + + NOTE: This requires the wrapped attrs to be defined with attrs.define(slots=False) because we need + to hack on a "_is_frozen" attribute. + + This decorator enhances an attrs-defined class with the ability to be "frozen" at runtime. + Once an instance is frozen, its attributes cannot be changed. It also recursively freezes + any attrs-defined objects that are attributes of the class. + + Usage: + @make_freezable + @attrs.define(slots=False) + class MyClass: + attribute1: int + attribute2: str + + obj = MyClass(1, 'a') + obj.freeze() # Freeze the instance + obj.attribute1 = 2 # Raises AttributeError + + Args: + cls: The class to be decorated. + + Returns: + The decorated class with added freezing capability. + """ + + if not hasattr(cls, "__dict__"): + raise TypeError( + "make_freezable cannot be used with classes that do not define __dict__. Make sure that the wrapped " + "class was defined with `@attrs.define(slots=False)`" + ) + + original_setattr = cls.__setattr__ + + def setattr_override(self, key, value) -> None: # noqa: ANN001 + """ + Override __setattr__ to allow modifications during initialization + and prevent modifications once the instance is frozen. + """ + if hasattr(self, "_is_frozen") and self._is_frozen and key != "_is_frozen": + raise AttributeError("Cannot modify frozen instance") + original_setattr(self, key, value) # type: ignore + + cls.__setattr__ = setattr_override # type: ignore + + def freeze(self: object) -> None: + """ + Freeze the instance and all its attrs-defined attributes. + """ + for _, value in attrs.asdict(self, recurse=False).items(): + if _is_attrs_instance(value) and hasattr(value, "freeze"): + value.freeze() + self._is_frozen = True # type: ignore + + cls.freeze = freeze # type: ignore + + return cls + +@make_freezable +@attrs.define(slots=False) +class DDPConfig: + # Traverse the computation graph to find parameters that don't receive gradients. + find_unused_parameters: bool = False + # Set to True if the computation graph does not change during the whole training loop. + static_graph: bool = True + # Set to True if we want to synchronize buffers. Set to False if the sync is going to be handled elsewhere. + broadcast_buffers: bool = True \ No newline at end of file diff --git a/cosmos_transfer1/utils/distributed.py b/cosmos_transfer1/utils/distributed.py index 4afecade..68aac2f8 100644 --- a/cosmos_transfer1/utils/distributed.py +++ b/cosmos_transfer1/utils/distributed.py @@ -21,14 +21,23 @@ import functools import os from datetime import timedelta -from typing import Any, Callable, Optional +from typing import Any, Callable, Optional, TypeVar import pynvml import torch import torch.distributed as dist +from torch.nn.parallel import DistributedDataParallel as DDP from cosmos_transfer1.utils import log from cosmos_transfer1.utils.device import Device +from cosmos_transfer1.utils.ddp_config import DDPConfig + +try: + from megatron.core import parallel_state +except ImportError: + print("Megatron-core is not installed.") + +T = TypeVar("T") def init() -> int | None: @@ -127,6 +136,52 @@ def barrier() -> None: dist.barrier() +def rank0_first(func: Callable) -> Callable: + """run the function on rank 0 first, then on other ranks.""" + + @functools.wraps(func) + def wrapper(*args, **kwargs): # noqa: ANN202 + if is_rank0(): + result = func(*args, **kwargs) + barrier() + if not is_rank0(): + result = func(*args, **kwargs) + return result + + return wrapper + + +def parallel_model_wrapper(config_ddp: DDPConfig, model: torch.nn.Module) -> torch.nn.Module | DDP: + """Wraps the model to enable data parallalism for training across multiple GPU devices. + + Args: + config_ddp (DDPConfig): The data parallel config. + model (torch.nn.Module): The PyTorch module. + + Returns: + model (torch.nn.Module | DistributedDataParallel): The data parallel model wrapper + if distributed environment is available, otherwise return the original model. + """ + if dist.is_available() and dist.is_initialized(): + local_rank = int(os.getenv("LOCAL_RANK", 0)) + try: + ddp_group = parallel_state.get_data_parallel_group(with_context_parallel=True) + except Exception as e: + log.info(e) + log.info("parallel_state not initialized, treating all GPUs equally for DDP") + ddp_group = None + + model = DDP( + model, + device_ids=[local_rank], + output_device=local_rank, + find_unused_parameters=config_ddp.find_unused_parameters, + static_graph=config_ddp.static_graph, + broadcast_buffers=config_ddp.broadcast_buffers, + process_group=ddp_group, + ) + return model + class DistributedDataParallel(torch.nn.parallel.DistributedDataParallel): """This extends torch.nn.parallel.DistributedDataParallel with .training_step(). diff --git a/cosmos_transfer1/utils/log.py b/cosmos_transfer1/utils/log.py index 822a9755..2c4975a0 100644 --- a/cosmos_transfer1/utils/log.py +++ b/cosmos_transfer1/utils/log.py @@ -76,6 +76,20 @@ def get_machine_format() -> str: return machine_format +def init_loguru_file(path: str) -> None: + machine_format = get_machine_format() + message_format = get_message_format() + logger.add( + path, + encoding="utf8", + level=LEVEL, + format="[{time:MM-DD HH:mm:ss}|" f"{machine_format}" f"{message_format}", + rotation="100 MB", + filter=lambda result: _rank0_only_filter(result) or not RANK0_ONLY, + enqueue=True, + ) + + def get_message_format() -> str: message_format = "{level}|{extra[relative_path]}:{line}:{function}] {message}" return message_format diff --git a/cosmos_transfer1/utils/misc.py b/cosmos_transfer1/utils/misc.py index 39bd30ea..f3b53b30 100644 --- a/cosmos_transfer1/utils/misc.py +++ b/cosmos_transfer1/utils/misc.py @@ -135,6 +135,18 @@ def serialize(data: Any) -> Any: data = str(data) return data +def print_environ_variables(env_vars: list[str]) -> None: + """Print a specific list of environment variables. + + Args: + env_vars (list[str]): List of specified environment variables. + """ + for env_var in env_vars: + if env_var in os.environ: + log.info(f"Environment variable {Color.green(env_var)}: {Color.yellow(os.environ[env_var])}") + else: + log.warning(f"Environment variable {Color.green(env_var)} not set!") + def set_random_seed(seed: int, by_rank: bool = False) -> None: """Set random seed. This includes random, numpy, Pytorch. @@ -224,6 +236,68 @@ def wrapper(*args, **kwargs): # noqa: ANN202 return wrapper # type: ignore +class TrainingTimer: + """Timer for timing the execution of code, aggregating over multiple training iterations. + + It is used as a context manager to measure the execution time of code and store the timing results + for each function. The context managers can be nested. + + Attributes: + results (dict): A dictionary to store timing results for various code. + + Example: + timer = Timer() + for i in range(100): + with timer("func_a"): + func_a() + avg_time = sum(timer.results["func_a"]) / len(timer.results["func_a"]) + print(f"func_a() took {avg_time} seconds.") + """ + + def __init__(self) -> None: + self.results = dict() + self.average_results = dict() + self.start_time = [] + self.func_stack = [] + self.reset() + + def reset(self) -> None: + self.results = {key: [] for key in self.results} + + def __enter__(self) -> TrainingTimer: + self.start_time.append(time.time()) + return self + + def __exit__(self, exc_type, exc_value, traceback) -> None: # noqa: ANN001 + end_time = time.time() + result = end_time - self.start_time.pop() + key = self.func_stack.pop() + self.results.setdefault(key, []) + self.results[key].append(result) + + def __call__(self, func_name: str) -> TrainingTimer: + self.func_stack.append(func_name) + return self + + def __getattr__(self, func_name: str) -> TrainingTimer: + return self.__call__(func_name) + + def nested(self, func_name: str) -> TrainingTimer: + return self.__call__(func_name) + + def compute_average_results(self) -> dict[str, float]: + results = dict() + for key, value_list in self.results.items(): + results[key] = sum(value_list) / len(value_list) + return results + + +def timeout_handler(timeout_period: float, signum: int, frame: int) -> None: + # What to do when the process gets stuck. For now, we simply end the process. + error_message = f"Timeout error: more than {timeout_period} seconds passed since the last iteration." + raise TimeoutError(error_message) + + class Color: """A convenience class to colorize strings in the console. diff --git a/scripts/convert_ckpt_fsdp_to_tp.py b/scripts/convert_ckpt_fsdp_to_tp.py index 901d0fa5..3975a1dd 100644 --- a/scripts/convert_ckpt_fsdp_to_tp.py +++ b/scripts/convert_ckpt_fsdp_to_tp.py @@ -105,7 +105,7 @@ def convert_fsdp_to_tp(path_in: str, path_out: str) -> None: easy_io.dump({'grad_scaler': {}, 'iteration': 0}, f"{path_out}.pt") for i in tqdm(range(TP_SIZE)): state_dict = {"model": state_dicts[i]} - easy_io.dump(state_dict, f"{path_out}_model_mp_{i}.pt") + easy_io.dump(state_dict, f"{path_out}_mp_{i}.pt") if __name__ == "__main__": @@ -116,9 +116,9 @@ def convert_fsdp_to_tp(path_in: str, path_out: str) -> None: python convert_ckpt_fsdp_to_tp.py checkpoints/nvidia/Cosmos-Transfer1-7B/vis_control.pt This will save the Tensor Parallel (TP) checkpoints as 8 files in the same directory: - checkpoints/nvidia/Cosmos-Transfer1-7B/vis_control_model_mp_0.pt + checkpoints/nvidia/Cosmos-Transfer1-7B/vis_control_mp_0.pt ... - checkpoints/nvidia/Cosmos-Transfer1-7B/vis_control_model_mp_7.pt + checkpoints/nvidia/Cosmos-Transfer1-7B/vis_control_mp_7.pt ''' if len(sys.argv) != 2: print("Usage: python convert_ckpt_fsdp_to_tp.py ") From 56e976f924efd2feee55b8b7abaa917e099fcf6d Mon Sep 17 00:00:00 2001 From: Qianli Ma Date: Wed, 16 Apr 2025 11:36:56 -0700 Subject: [PATCH 10/10] feat+fix: multiple fixes + refinements to README --- README.md | 1 - .../checkpointer/ddp_checkpointer.py | 3 +- cosmos_transfer1/checkpointer/fast_tp.py | 5 +- .../diffusion/config/training/callbacks.py | 1 + .../experiment/ctrl_7b_tp_121frames.py | 8 +- .../config/training/registry_extra.py | 10 +- .../diffusion/config/training/tokenizer.py | 43 +- .../datasets/augmentors/control_input.py | 242 +++++++++- .../datasets/example_transfer_dataset.py | 30 +- cosmos_transfer1/diffusion/training/README.md | 3 + .../training/modules/pretrained_vae.py | 447 +++++++++--------- cosmos_transfer1/utils/distributed.py | 86 +++- cosmos_transfer1/utils/trainer.py | 2 +- examples/training_cosmos_transfer_7b.md | 55 ++- scripts/convert_ckpt_fsdp_to_tp.py | 8 +- scripts/download_diffusion_example_data.py | 121 +++++ 16 files changed, 750 insertions(+), 315 deletions(-) create mode 100644 cosmos_transfer1/diffusion/training/README.md create mode 100644 scripts/download_diffusion_example_data.py diff --git a/README.md b/README.md index 1ba9fbf7..68ca8754 100644 --- a/README.md +++ b/README.md @@ -54,7 +54,6 @@ Please refer to [INSTALL.md](INSTALL.md) for general instructions on environment ### Post-train pre-trained Cosmos-Transfer1 models -* Post-train diffusion-based Text2World models using custom datasets [with multi-node support]Coming soon * [Post-train pre-trained Cosmos-Transfer1-7B [Depth|Segmentation|Edge|Vis|Keypoint]](examples/training_cosmos_transfer_7b.md) **[with multi-GPU support]** * Post-train pre-trained Cosmos-Transfer1-7B-Sample-AV [LiDAR]: Coming soon * Post-train pre-trained Cosmos-Transfer1-7B-Sample-AV [HDMap]: Coming soon diff --git a/cosmos_transfer1/checkpointer/ddp_checkpointer.py b/cosmos_transfer1/checkpointer/ddp_checkpointer.py index 6bab42c4..12da16cd 100644 --- a/cosmos_transfer1/checkpointer/ddp_checkpointer.py +++ b/cosmos_transfer1/checkpointer/ddp_checkpointer.py @@ -171,7 +171,8 @@ def _save_worker(self, state_dict: dict[str, StateDictItemPath], checkpoint_file item.save_path, fast_backend=True, # optional for fast backend, cpu heavy ) - self.print(f"Saved {key} to {item.save_path}") + abs_path = os.path.abspath(item.save_path) + self.print(f"Saved {key} to {item.save_path}, abspath = {abs_path}") except Exception as e: self.print(f"Failed to save {key} to {item.save_path}: {str(e)}") raise # Re-raise the exception after logging diff --git a/cosmos_transfer1/checkpointer/fast_tp.py b/cosmos_transfer1/checkpointer/fast_tp.py index 16732924..62818447 100644 --- a/cosmos_transfer1/checkpointer/fast_tp.py +++ b/cosmos_transfer1/checkpointer/fast_tp.py @@ -66,7 +66,7 @@ def load_broadcast_state_dict( sorted_resume_keys = sorted(resume_keys) for key in sorted_resume_keys: _ckpt_path = self.add_type_postfix_to_checkpoint_path(key, checkpoint_path, model) - _state_dict = easy_io.load(_ckpt_path, fast_backend=True, backend_key=self.load_s3_backend_key) + _state_dict = easy_io.load(_ckpt_path, weights_only=False) state_dict[key] = _state_dict self.print(f"Loaded checkpoint from: {_ckpt_path}") distributed.barrier() @@ -85,8 +85,7 @@ def _save_worker(self, state_dict: dict[str, StateDictItemPath], checkpoint_file easy_io.dump( item.state_dict, item.save_path, - fast_backend=False, # too cpu heavy - backend_key=self.save_s3_backend_key, + # fast_backend=False, # too cpu heavy ) self.print(f"Saved {key} to {item.save_path}") except Exception as e: diff --git a/cosmos_transfer1/diffusion/config/training/callbacks.py b/cosmos_transfer1/diffusion/config/training/callbacks.py index 6270c714..50d44d96 100644 --- a/cosmos_transfer1/diffusion/config/training/callbacks.py +++ b/cosmos_transfer1/diffusion/config/training/callbacks.py @@ -25,5 +25,6 @@ progress_bar=L(ProgressBarCallback)(), grad_clip=L(GradClip)(fsdp_enabled=True, model_key="model"), low_prec=L(LowPrecisionCallback)(config=PLACEHOLDER, trainer=PLACEHOLDER, update_iter=1), + # for the first 1000 iterations, log the iteration speed per iteration, after that, log every 200 iterations iter_speed=L(IterSpeed)(every_n=200, hit_thres=1000), ) \ No newline at end of file diff --git a/cosmos_transfer1/diffusion/config/training/experiment/ctrl_7b_tp_121frames.py b/cosmos_transfer1/diffusion/config/training/experiment/ctrl_7b_tp_121frames.py index 9d264ed2..eb957c14 100644 --- a/cosmos_transfer1/diffusion/config/training/experiment/ctrl_7b_tp_121frames.py +++ b/cosmos_transfer1/diffusion/config/training/experiment/ctrl_7b_tp_121frames.py @@ -90,14 +90,14 @@ def make_ctrlnet_config_7b_training( checkpoint=dict( load_path=pretrain_model_path, # Modify load_path as needed if you do post-training (fine-tuning). If training from scratch, leave it empty. broadcast_via_filesystem=True, - save_iter=1000, + save_iter=1000, # 1000 iterations per checkpoint. Update as needed. load_training_state=False, strict_resume=True, keys_not_to_resume=[], ), trainer=dict( distributed_parallelism="ddp", - logging_iter=200, + logging_iter=200, # will log iter speed, loss, etc. every 200 iterations. (Will log per-iteration speed for the first 1000 iterations.) max_iter=999_999_999, timestamp_seed=True, ), @@ -111,12 +111,12 @@ def make_ctrlnet_config_7b_training( loss_reduce='mean', latent_shape=[ 16, - (num_frames - 1) // 8 + 1, + (num_frames - 1) // 8 + 1, # for 121 frames, this is 16 88, 160, ], base_load_from=dict( - load_path=os.path.join("checkpoints", COSMOS_TRANSFER1_7B_CHECKPOINT, "checkpoints_tp", "base_model_mp_*.pt") + load_path=os.path.join("checkpoints", COSMOS_TRANSFER1_7B_CHECKPOINT, "checkpoints_tp", "base_model_model_mp_*.pt") ), # modify as needed. This is the TP version of base model ckpt (that's frozen during training). finetune_base_model=False, hint_mask=[True] * len(CTRL_HINT_KEYS_COMB[hint_key]), diff --git a/cosmos_transfer1/diffusion/config/training/registry_extra.py b/cosmos_transfer1/diffusion/config/training/registry_extra.py index 0ddc70a4..c1f9c29c 100644 --- a/cosmos_transfer1/diffusion/config/training/registry_extra.py +++ b/cosmos_transfer1/diffusion/config/training/registry_extra.py @@ -27,7 +27,7 @@ from cosmos_transfer1.diffusion.training.networks.general_dit_ctrl_enc import GeneralDITEncoder from cosmos_transfer1.diffusion.training.networks.general_dit import GeneralDIT from cosmos_transfer1.diffusion.config.training.tokenizer import get_cosmos_diffusion_tokenizer_comp8x8x8 -from cosmos_transfer1.diffusion.config.registry import register_tokenizer +# from cosmos_transfer1.diffusion.config.registry import register_tokenizer from cosmos_transfer1.utils.lazy_config import LazyCall as L from cosmos_transfer1.utils.lazy_config import LazyDict import copy @@ -86,6 +86,14 @@ def register_conditioner_ctrlnet(cs): node=VideoConditionerFpsSizePaddingWithCtrlConfig, ) +def register_tokenizer(cs): + cs.store( + group="tokenizer", + package="model.tokenizer", + name="cosmos_diffusion_tokenizer_res720_comp8x8x8_t121_ver092624", + node=get_cosmos_diffusion_tokenizer_comp8x8x8(resolution="720", chunk_duration=121), + ) + def register_configs(): cs = ConfigStore.instance() diff --git a/cosmos_transfer1/diffusion/config/training/tokenizer.py b/cosmos_transfer1/diffusion/config/training/tokenizer.py index c580d722..503e8bdb 100644 --- a/cosmos_transfer1/diffusion/config/training/tokenizer.py +++ b/cosmos_transfer1/diffusion/config/training/tokenizer.py @@ -41,8 +41,7 @@ def get_cosmos_diffusion_tokenizer_comp8x8x8(resolution: str, chunk_duration: in temporal_compression_factor = 8 spatial_compression_factor = 8 - return L(JointImageVideoSharedJITTokenizer)( - video_vae=L(VideoJITTokenizer)( + return L(VideoJITTokenizer)( name="cosmos_1_0_diffusion_tokenizer", enc_fp="checkpoints/nvidia/Cosmos-Tokenize1-CV8x8x8-720p/encoder.jit", dec_fp="checkpoints/nvidia/Cosmos-Tokenize1-CV8x8x8-720p/decoder.jit", @@ -53,16 +52,30 @@ def get_cosmos_diffusion_tokenizer_comp8x8x8(resolution: str, chunk_duration: in temporal_compression_factor=temporal_compression_factor, spatial_compression_factor=spatial_compression_factor, spatial_resolution=resolution, - ), - image_vae=L(JITVAE)( - name="cosmos_1_0_diffusion_tokenizer", - enc_fp="checkpoints/nvidia/Cosmos-Tokenize1-CV8x8x8-720p/encoder.jit", - dec_fp="checkpoints/nvidia/Cosmos-Tokenize1-CV8x8x8-720p/decoder.jit", - mean_std_fp="checkpoints/nvidia/Cosmos-Tokenize1-CV8x8x8-720p/mean_std.pt", - latent_ch=16, - is_image=False, - is_bf16=True, - ), - name="cosmos_1_0_diffusion_tokenizer", - latent_ch=16, - ) \ No newline at end of file + ) + + # return L(JointImageVideoSharedJITTokenizer)( + # video_vae=L(VideoJITTokenizer)( + # name="cosmos_1_0_diffusion_tokenizer", + # enc_fp="checkpoints/nvidia/Cosmos-Tokenize1-CV8x8x8-720p/encoder.jit", + # dec_fp="checkpoints/nvidia/Cosmos-Tokenize1-CV8x8x8-720p/decoder.jit", + # mean_std_fp="checkpoints/nvidia/Cosmos-Tokenize1-CV8x8x8-720p/mean_std.pt", + # latent_ch=16, + # is_bf16=True, + # pixel_chunk_duration=pixel_chunk_duration, + # temporal_compression_factor=temporal_compression_factor, + # spatial_compression_factor=spatial_compression_factor, + # spatial_resolution=resolution, + # ), + # image_vae=L(JITVAE)( + # name="cosmos_1_0_diffusion_tokenizer", + # enc_fp="checkpoints/nvidia/Cosmos-Tokenize1-CV8x8x8-720p/encoder.jit", + # dec_fp="checkpoints/nvidia/Cosmos-Tokenize1-CV8x8x8-720p/decoder.jit", + # mean_std_fp="checkpoints/nvidia/Cosmos-Tokenize1-CV8x8x8-720p/mean_std.pt", + # latent_ch=16, + # is_image=False, + # is_bf16=True, + # ), + # name="cosmos_1_0_diffusion_tokenizer", + # latent_ch=16, + # ) \ No newline at end of file diff --git a/cosmos_transfer1/diffusion/datasets/augmentors/control_input.py b/cosmos_transfer1/diffusion/datasets/augmentors/control_input.py index 4f279c49..d0bd7159 100644 --- a/cosmos_transfer1/diffusion/datasets/augmentors/control_input.py +++ b/cosmos_transfer1/diffusion/datasets/augmentors/control_input.py @@ -33,6 +33,11 @@ LaplacianOfGaussianConfig, MedianBlurConfig, ) +from cosmos_transfer1.diffusion.datasets.augmentors.human_keypoint_utils import ( + convert_coco_to_openpose, + openpose134_skeleton, + coco_wholebody_133_skeleton, +) from cosmos_transfer1.diffusion.datasets.augmentors.guided_filter import FastGuidedFilter from cosmos_transfer1.utils import log @@ -1043,6 +1048,197 @@ def __init__( self.hand_as_separate_channel = args.get("hand_as_separate_channel", False) self.kpt_thr = args.get("kpt_thr", 0.6) self.line_width = args.get("human_kpt_line_width", 4) + + def denormalize_pose_kpts(self, pose_kps: np.ndarray, h: int, w: int): + """ + pose_kps has shape = (#keypoints, 2) + or (#keypoints, 3) where the last dim is the confidence score. + """ + if pose_kps is not None: + assert pose_kps.shape[-1] == 3, "pose_kps must have shape (#keypoints, 3)" + out = pose_kps * np.array([w, h, 1]) + return out + else: + return None + + def draw_skeleton( + self, + img: np.ndarray, + keypoints: np.ndarray, + scores: np.ndarray, + kpt_thr: float = 0.6, + openpose_format: bool = False, + radius: int = 2, + line_width: int = 4, + ): + skeleton_topology = openpose134_skeleton if openpose_format else coco_wholebody_133_skeleton + assert len(keypoints.shape) == 2 + keypoint_info, skeleton_info = ( + skeleton_topology["keypoint_info"], + skeleton_topology["skeleton_info"], + ) + vis_kpt = [s >= kpt_thr for s in scores] + + if self.hand_as_separate_channel: + img_hand = np.zeros_like(img) + hand_idx_start = 92 if openpose_format else 91 # all idx after this are hand keypoints + + link_dict = {} + for i, kpt_info in keypoint_info.items(): + kpt_color = tuple(kpt_info["color"]) + link_dict[kpt_info["name"]] = kpt_info["id"] + + kpt = keypoints[i] + + if vis_kpt[i]: + img = cv2.circle(img, (int(kpt[0]), int(kpt[1])), int(radius), kpt_color, -1) + + if self.hand_as_separate_channel: + if i >= hand_idx_start: + img_hand = cv2.circle(img_hand, (int(kpt[0]), int(kpt[1])), int(radius), kpt_color, -1) + + for i, ske_info in skeleton_info.items(): + link = ske_info["link"] + pt0, pt1 = link_dict[link[0]], link_dict[link[1]] + + if vis_kpt[pt0] and vis_kpt[pt1]: + link_color = ske_info["color"] + kpt0 = keypoints[pt0] + kpt1 = keypoints[pt1] + + img = cv2.line( + img, (int(kpt0[0]), int(kpt0[1])), (int(kpt1[0]), int(kpt1[1])), link_color, thickness=line_width + ) + + if self.hand_as_separate_channel: + if pt0 >= hand_idx_start and pt1 >= hand_idx_start: + img_hand = cv2.line( + img_hand, + (int(kpt0[0]), int(kpt0[1])), + (int(kpt1[0]), int(kpt1[1])), + link_color, + thickness=line_width, + ) + + if self.hand_as_separate_channel: + img = np.concatenate([img, img_hand], axis=-1) # [h,w,6] + return img + + def plot_person_kpts( + self, + person_dict: dict, + pose_vis_img: np.ndarray, + h: int, + w: int, + kpt_thr: float = 0.6, + openpose_format: bool = True, + line_width: int = 4, + ) -> np.ndarray: + """ + plot a single person + in-place update the pose image + """ + try: + body_keypoints = self.denormalize_pose_kpts(person_dict.get("body-keypoints"), h, w) + hand_keypoints = self.denormalize_pose_kpts(person_dict.get("hand-keypoints"), h, w) + except Exception as e: + log.error(f"Error in denormalizing keypoints: {e}") + + assert ( + body_keypoints is not None and hand_keypoints is not None + ), "Both body and hand keypoints must be present." + # all_keypoints: shape=(133, 3). following coco-fullbody skeleton config. 3 channels are x, y, confidence + all_keypoints = np.vstack([body_keypoints, hand_keypoints]) + kpts, scores = all_keypoints[..., :2], all_keypoints[..., -1] + if openpose_format: + kpts, scores = convert_coco_to_openpose(kpts, scores) + + try: + # [h,w,3] or # [h,w,6] if hand_as_separate_channel + pose_vis_img = self.draw_skeleton( + pose_vis_img, kpts, scores, kpt_thr=kpt_thr, openpose_format=openpose_format, line_width=line_width + ) + except ValueError as e: + log.error(f"Error in draw_skeleton func, {e}") + + return pose_vis_img + + def plot_kpt_video( + self, + kpts_np_dict: dict, + h: int, + w: int, + kpt_thr: float = 0.6, + openpose_format: bool = True, + line_width: int = 4, + ) -> np.ndarray: + """ + Plots a single *frame* for all persons in the frame. + + The raw human keypoint annotation are numpy arrays of pixel coordinates of the joints. + This function plots the keypoints on a black background to form a 3-channel image compatible with controlnet. + + Args: + kpts_np_dict (dict): A dict of keypoint annotations. Each value is a frame's annotation (a list of per-person dict). + H (int): height of the image + W (int): width of the image + openpose_format (bool): whether the convert the coco-wholebody133 keypoints keypoints to openpose format and also + plot in the openpose format (basically add neck keypoint, remove toe keypoints). + Returns: + np.ndarray: keypoints of plotted on black background, shape = (C, T, H, W) C=3, or 6 if hand_as_separate_channel + """ + T = len(kpts_np_dict) + + out = np.empty((3, T, h, w), dtype=np.uint8) # memory save op + + for idx, (t, kpts_np_frame) in enumerate(kpts_np_dict.items()): + pose_vis_img = np.zeros([h, w, 3]) + + # add each person's keypoints to this frame's pose image + for person_dict in kpts_np_frame: + self.plot_person_kpts( + person_dict, + pose_vis_img, + h, + w, + kpt_thr=kpt_thr, + openpose_format=openpose_format, + line_width=line_width, + ) # (h, w, 3) + + out[:, idx, :, :] = pose_vis_img.astype(np.uint8).transpose(2, 0, 1) + + return out + + def get_kpts_from_annotations(self, annotation_dict: dict, total_frames: int, frame_indices: list) -> dict: + """ + For legacy data the annotations are done for chunks of every N frames (N=4). + This function repeats each chunk's first frame annotation to the rest frames + so that they become 'per-frame' and are ControlNet compatible. + + If the data is already per-frame annotated, then no need to call this. + Args: + annotation_dict (dict): Original annotations annotated every chunk_size frames. + Each value is a list of dicts, where each dict has many + human attributes. Here we only keep keypoint-relevant keys. + total_frames (int): Total number of frames in the video. + frame_indices (list[int]): Indices of the video frames sampled from the the original video. + + Returns: + dict: extended annotations for all frames. + """ + annotated_frame_idxs = sorted(list(annotation_dict.keys())) + chunk_size = annotated_frame_idxs[1] - annotated_frame_idxs[0] + assert chunk_size == 1, "Only support videos that have human annotations for every frame" + + # each person's dict can contain many irrelevant annotations (e.g. seg masks), here we only keep pose kpts + annotation_dict_simpler = { + key: [{k: v for k, v in sub_dict.items() if k in self.control_key_names} for sub_dict in sub_list] + for key, sub_list in annotation_dict.items() + } + annotation_dict_simpler = {idx: annotation_dict_simpler[idx] for idx in frame_indices} + + return annotation_dict_simpler def __call__(self, data_dict: dict) -> dict: """ @@ -1065,13 +1261,55 @@ def __call__(self, data_dict: dict) -> dict: } Note that for the same person, their idx in the per-frame list isn't guaranteed to be consistent. """ - if "control_input_keypoint" in data_dict: + if "control_input_human_kpts" in data_dict: # already processed log.info( - f"control_input_keypoint already processed, shape={data_dict['control_input_keypoint'].shape}, dtype={data_dict['control_input_keypoint'].dtype}, value range: {data_dict['control_input_keypoint'].min()}, {data_dict['control_input_keypoint'].max()}" + f"control_input_human_kpts already processed, shape={data_dict['control_input_human_kpts'].shape}, dtype={data_dict['control_input_human_kpts'].dtype}, value range: {data_dict['control_input_human_kpts'].min()}, {data_dict['control_input_human_kpts'].max()}" ) return data_dict + human_annotations = data_dict.pop("keypoint") + frames = data_dict["video"] + _, T, H, W = frames.shape + + # the frames here are a randomly sampled (e.g. 121-frame) chunk from the original video + # so we need to accordingly only use the human annotations of the sampled frames. + frame_start = data_dict["frame_start"] + frame_end = data_dict["frame_end"] + frame_indices = np.arange(frame_start, frame_end).tolist() + assert ( + len(frame_indices) == T + ), f"frame_indices length {len(frame_indices)} != T {T}, likely due to video decoder using different fps, i.e. sample with stride. Need to return frame indices from video decoder." + + try: + # same dict format as `human_annotations` but now every frame has an annotation + kpts_nparray_dict = self.get_kpts_from_annotations(human_annotations, T, frame_indices) + except ValueError as e: + log.error(f"Error in loading kpts from annotated data: {e}") + kpts_nparray_dict = {} + raise e + + try: + # Colored human keypoints plotted on black background. All persons in the same frame are plotted together. + # np.array of shape: [C, T, H, W]. + kpts_cond_video = self.plot_kpt_video( + kpts_nparray_dict, + H, + W, + kpt_thr=self.kpt_thr, + openpose_format=self.use_openpose_format, + line_width=self.line_width, + ) + except ValueError as e: + log.error(f"Error in plot_kpt_video: {e}") + kpts_cond_video = np.zeros_like(frames) + + key_out = self.output_keys[0] + + data_dict[key_out] = torch.from_numpy(kpts_cond_video) + return data_dict + + class AddControlInputUpscale(Augmentor): """ Add control input to the data dictionary. control input are expanded to 3-channels diff --git a/cosmos_transfer1/diffusion/datasets/example_transfer_dataset.py b/cosmos_transfer1/diffusion/datasets/example_transfer_dataset.py index c856a5fe..1a8fdd1a 100644 --- a/cosmos_transfer1/diffusion/datasets/example_transfer_dataset.py +++ b/cosmos_transfer1/diffusion/datasets/example_transfer_dataset.py @@ -134,8 +134,8 @@ def _load_control_data(self, sample): f"Depth video {ctrl_path} has fewer frames than main video" # Load the corresponding frames - depth_frames = vr.get_batch(frame_ids).asnumpy() - depth_frames = torch.from_numpy(depth_frames).permute(0, 3, 1, 2) # [T,C,H,W] + depth_frames = vr.get_batch(frame_ids).asnumpy() # [T,H,W,C] + depth_frames = torch.from_numpy(depth_frames).permute(3, 0, 1, 2) # [C,T,H,W], same as rgb video data_dict["depth"] = { "video": depth_frames, "frame_start": frame_ids[0], @@ -165,7 +165,7 @@ def __getitem__(self, index): # Process video frames video = torch.from_numpy(frames).permute(3, 0, 1, 2) # [T,H,W,C] -> [C,T,H,W] - aspect_ratio = detect_aspect_ratio((video.shape[3], video.shape[2])) # expects (w, h) + aspect_ratio = detect_aspect_ratio((video.shape[3], video.shape[2])) # expects (W, H) # Basic data data["video"] = video @@ -179,17 +179,17 @@ def __getitem__(self, index): # Load T5 embeddings with open(data["video_name"]["t5_embedding_path"], "rb") as f: t5_embedding = pickle.load(f)[0] - data["t5_text_embeddings"] = torch.from_numpy(t5_embedding).cuda() - data["t5_text_mask"] = torch.ones(512, dtype=torch.int64).cuda() + data["t5_text_embeddings"] = torch.from_numpy(t5_embedding) #.cuda() + data["t5_text_mask"] = torch.ones(512, dtype=torch.int64) #.cuda() # Add metadata data["fps"] = fps data["frame_start"] = frame_ids[0] data["frame_end"] = frame_ids[-1] + 1 data["num_frames"] = self.sequence_length - data["image_size"] = torch.tensor([704, 1280, 704, 1280]).cuda() - data["padding_mask"] = torch.zeros(1, 704, 1280).cuda() - + data["image_size"] = torch.tensor([704, 1280, 704, 1280]) #.cuda() + data["padding_mask"] = torch.zeros(1, 704, 1280) #.cuda() + if self.ctrl_type: ctrl_data = self._load_control_data({ "ctrl_path": os.path.join( @@ -202,10 +202,12 @@ def __getitem__(self, index): if ctrl_data is None: # Control data loading failed index = np.random.randint(len(self.video_paths)) continue - data.update(ctrl_data) + data.update(ctrl_data) - # Apply augmentations including control input processing - for aug_name, aug_fn in self.augmentor.items(): + # The ctrl_data above is the 'raw' data loaded (e.g. a loaded segmentation pkl). + # Next, we process it into the control input "video" tensor that the model expects. + # This is done in the augmentor. + for _, aug_fn in self.augmentor.items(): data = aug_fn(data) return data @@ -232,8 +234,8 @@ def __str__(self): ''' Sanity check for the dataset. ''' - control_input_key = "control_input_edge" - visualize_control_input = False + control_input_key = "control_input_keypoint" + visualize_control_input = True dataset = ExampleTransferDataset( dataset_dir="datasets/hdvila/", @@ -262,5 +264,5 @@ def __str__(self): if visualize_control_input: import imageio control_input_tensor = data[control_input_key].permute(1, 2, 3, 0).cpu().numpy() - video_name = "control_input_edge.mp4" + video_name = f"{control_input_key}.mp4" imageio.mimsave(video_name, control_input_tensor, fps=24) diff --git a/cosmos_transfer1/diffusion/training/README.md b/cosmos_transfer1/diffusion/training/README.md new file mode 100644 index 00000000..1a499b1c --- /dev/null +++ b/cosmos_transfer1/diffusion/training/README.md @@ -0,0 +1,3 @@ +# Training Modules + +This folder contains specialized versions of models and modules optimized for training. While some components (for example, the `GeneralDIT` defined in `training/networks/general_dit.py`) may appear duplicated from elsewhere in the repository, they include training-specific functionality including gradient checkpointing, training steps, tensor parallel and sequence parallel support, etc. \ No newline at end of file diff --git a/cosmos_transfer1/diffusion/training/modules/pretrained_vae.py b/cosmos_transfer1/diffusion/training/modules/pretrained_vae.py index 21769467..20649947 100644 --- a/cosmos_transfer1/diffusion/training/modules/pretrained_vae.py +++ b/cosmos_transfer1/diffusion/training/modules/pretrained_vae.py @@ -126,6 +126,157 @@ def is_chunk_overlap(self): return False +class BasePretrainedVideoTokenizer(ABC): + """ + Base class for a pretrained video tokenizer that handles chunking of video data for efficient processing. + + Args: + pixel_chunk_duration (int): The duration (in number of frames) of each chunk of video data at the pixel level. + temporal_compress_factor (int): The factor by which the video data is temporally compressed during processing. + max_enc_batch_size (int): The maximum batch size to process in one go during encoding to avoid memory overflow. + max_dec_batch_size (int): The maximum batch size to process in one go during decoding to avoid memory overflow. + + The class introduces parameters for managing temporal chunks (`pixel_chunk_duration` and `temporal_compress_factor`) + which define how video data is subdivided and compressed during the encoding and decoding processes. The + `max_enc_batch_size` and `max_dec_batch_size` parameters allow processing in smaller batches to handle memory + constraints. + """ + + def __init__( + self, + pixel_chunk_duration: int = 17, + temporal_compress_factor: int = 8, + max_enc_batch_size: int = 8, + max_dec_batch_size: int = 4, + ): + self._pixel_chunk_duration = pixel_chunk_duration + self._temporal_compress_factor = temporal_compress_factor + self.max_enc_batch_size = max_enc_batch_size + self.max_dec_batch_size = max_dec_batch_size + + def register_mean_std(self, mean_std_fp: str) -> None: + latent_mean, latent_std = torch.load(mean_std_fp, map_location="cuda", weights_only=True) + latent_mean = latent_mean.view(self.latent_ch, -1)[:, : self.latent_chunk_duration] + latent_std = latent_std.view(self.latent_ch, -1)[:, : self.latent_chunk_duration] + + target_shape = [1, self.latent_ch, self.latent_chunk_duration, 1, 1] + + self.register_buffer( + "latent_mean", + latent_mean.to(self.dtype).reshape(*target_shape), + persistent=False, + ) + self.register_buffer( + "latent_std", + latent_std.to(self.dtype).reshape(*target_shape), + persistent=False, + ) + + def transform_encode_state_shape(self, state: torch.Tensor) -> torch.Tensor: + """ + Rearranges the input state tensor to the required shape for encoding video data. Mainly for chunk based encoding + """ + B, C, T, H, W = state.shape + assert ( + T % self.pixel_chunk_duration == 0 + ), f"Temporal dimension {T} is not divisible by chunk_length {self.pixel_chunk_duration}" + return rearrange(state, "b c (n t) h w -> (b n) c t h w", t=self.pixel_chunk_duration) + + def transform_decode_state_shape(self, latent: torch.Tensor) -> None: + B, _, T, _, _ = latent.shape + assert ( + T % self.latent_chunk_duration == 0 + ), f"Temporal dimension {T} is not divisible by chunk_length {self.latent_chunk_duration}" + return rearrange(latent, "b c (n t) h w -> (b n) c t h w", t=self.latent_chunk_duration) + + @torch.no_grad() + def encode(self, state: torch.Tensor) -> torch.Tensor: + if self._temporal_compress_factor == 1: + _, _, origin_T, _, _ = state.shape + state = rearrange(state, "b c t h w -> (b t) c 1 h w") + B, C, T, H, W = state.shape + state = self.transform_encode_state_shape(state) + # use max_enc_batch_size to avoid OOM + if state.shape[0] > self.max_enc_batch_size: + latent = [] + for i in range(0, state.shape[0], self.max_enc_batch_size): + latent.append(super().encode(state[i : i + self.max_enc_batch_size])) + latent = torch.cat(latent, dim=0) + else: + latent = super().encode(state) + + latent = rearrange(latent, "(b n) c t h w -> b c (n t) h w", b=B) + if self._temporal_compress_factor == 1: + latent = rearrange(latent, "(b t) c 1 h w -> b c t h w", t=origin_T) + return latent + + @torch.no_grad() + def decode(self, latent: torch.Tensor) -> torch.Tensor: + """ + Decodes a batch of latent representations into video frames by applying temporal chunking. Similar to encode, + it handles video data by processing smaller temporal chunks to reconstruct the original video dimensions. + + It can also decode single frame image data. + + Args: + latent (torch.Tensor): The latent space tensor containing encoded video data. + + Returns: + torch.Tensor: The decoded video tensor reconstructed from latent space. + """ + if self._temporal_compress_factor == 1: + _, _, origin_T, _, _ = latent.shape + latent = rearrange(latent, "b c t h w -> (b t) c 1 h w") + B, _, T, _, _ = latent.shape + latent = self.transform_decode_state_shape(latent) + # use max_enc_batch_size to avoid OOM + if latent.shape[0] > self.max_dec_batch_size: + state = [] + for i in range(0, latent.shape[0], self.max_dec_batch_size): + state.append(super().decode(latent[i : i + self.max_dec_batch_size])) + state = torch.cat(state, dim=0) + else: + state = super().decode(latent) + assert state.shape[2] == self.pixel_chunk_duration + state = rearrange(state, "(b n) c t h w -> b c (n t) h w", b=B) + if self._temporal_compress_factor == 1: + return rearrange(state, "(b t) c 1 h w -> b c t h w", t=origin_T) + return state + + @property + def pixel_chunk_duration(self) -> int: + return self._pixel_chunk_duration + + @property + def latent_chunk_duration(self) -> int: + # return self._latent_chunk_duration + assert (self.pixel_chunk_duration - 1) % self.temporal_compression_factor == 0, ( + f"Pixel chunk duration {self.pixel_chunk_duration} is not divisible by latent chunk duration " + f"{self.latent_chunk_duration}" + ) + return (self.pixel_chunk_duration - 1) // self.temporal_compression_factor + 1 + + @property + def temporal_compression_factor(self): + return self._temporal_compress_factor + + def get_latent_num_frames(self, num_pixel_frames: int) -> int: + if num_pixel_frames == 1: + return 1 + assert ( + num_pixel_frames % self.pixel_chunk_duration == 0 + ), f"Temporal dimension {num_pixel_frames} is not divisible by chunk_length {self.pixel_chunk_duration}" + return num_pixel_frames // self.pixel_chunk_duration * self.latent_chunk_duration + + def get_pixel_num_frames(self, num_latent_frames: int) -> int: + if num_latent_frames == 1: + return 1 + assert ( + num_latent_frames % self.latent_chunk_duration == 0 + ), f"Temporal dimension {num_latent_frames} is not divisible by chunk_length {self.latent_chunk_duration}" + return num_latent_frames // self.latent_chunk_duration * self.pixel_chunk_duration + + class BasePretrainedImageVAE(BaseVAE): """ A base class for pretrained Variational Autoencoder (VAE) that loads mean and standard deviation values @@ -377,231 +528,81 @@ def load_decoder(self, dec_fp: str) -> None: # self.vae.to(self.dtype) -class SDVAE(BaseVAE): - def __init__(self, batch_size=16, count_std: bool = False, is_downsample: bool = True) -> None: - super().__init__(channel=4, name="sd_vae") - self.dtype = torch.bfloat16 - self.register_buffer( - "scale", - torch.tensor([4.17, 4.62, 3.71, 3.28], dtype=self.dtype).reciprocal().reshape(1, -1, 1, 1), - persistent=False, - ) - self.register_buffer( - "bias", - -1.0 * torch.tensor([5.81, 3.25, 0.12, -2.15], dtype=self.dtype).reshape(1, -1, 1, 1) * self.scale, - persistent=False, - ) - self.batch_size = batch_size - self.count_std = count_std - self.is_downsample = is_downsample - self.load_vae() - self.reset_dtype() - - def reset_dtype(self, *args, **kwargs): - del args, kwargs - self.vae.to(self.dtype) - - @rank0_first - def load_vae(self) -> None: - os.environ["HF_HUB_DISABLE_SYMLINKS_WARNING"] = "1" - os.environ["HF_HUB_DISABLE_PROGRESS_BARS"] = "1" - import diffusers - - vae_name = "stabilityai/sd-vae-ft-mse" - try: - vae = diffusers.models.AutoencoderKL.from_pretrained(vae_name, local_files_only=True) - except: # noqa: E722 - # Could not load the model from cache; try without local_files_only. - vae = diffusers.models.AutoencoderKL.from_pretrained(vae_name) - self.vae = vae.eval().requires_grad_(False) - - @torch.no_grad() - def encode(self, state: torch.Tensor) -> torch.Tensor: - """ - state : pixel range [-1, 1] - """ - if self.is_downsample: - _h, _w = state.shape[-2:] - state = F.interpolate(state, size=(_h // 2, _w // 2), mode="bilinear", align_corners=False) - in_dtype = state.dtype - state = state.to(self.dtype) - state = (state + 1.0) / 2.0 - latent_dist = self.vae.encode(state)["latent_dist"] - mean, std = latent_dist.mean, latent_dist.std - if self.count_std: - latent = mean + torch.randn_like(mean) * std - else: - latent = mean - latent = latent * self.scale - latent = latent + self.bias - return latent.to(in_dtype) - - @torch.no_grad() - def decode(self, latent: torch.Tensor) -> torch.Tensor: - in_dtype = latent.dtype - latent = latent.to(self.dtype) - latent = latent - self.bias - latent = latent / self.scale - latent = torch.cat([self.vae.decode(batch)["sample"] for batch in latent.split(self.batch_size)]) - if self.is_downsample: - _h, _w = latent.shape[-2:] - latent = F.interpolate(latent, size=(_h * 2, _w * 2), mode="bilinear", align_corners=False) - return latent.to(in_dtype) * 2 - 1.0 - - @property - def spatial_compression_factor(self) -> int: - return 8 - - -class BasePretrainedVideoTokenizer(ABC): - """ - Base class for a pretrained video tokenizer that handles chunking of video data for efficient processing. - - Args: - pixel_chunk_duration (int): The duration (in number of frames) of each chunk of video data at the pixel level. - temporal_compress_factor (int): The factor by which the video data is temporally compressed during processing. - max_enc_batch_size (int): The maximum batch size to process in one go during encoding to avoid memory overflow. - max_dec_batch_size (int): The maximum batch size to process in one go during decoding to avoid memory overflow. - - The class introduces parameters for managing temporal chunks (`pixel_chunk_duration` and `temporal_compress_factor`) - which define how video data is subdivided and compressed during the encoding and decoding processes. The - `max_enc_batch_size` and `max_dec_batch_size` parameters allow processing in smaller batches to handle memory - constraints. - """ - - def __init__( - self, - pixel_chunk_duration: int = 17, - temporal_compress_factor: int = 8, - max_enc_batch_size: int = 8, - max_dec_batch_size: int = 4, - ): - self._pixel_chunk_duration = pixel_chunk_duration - self._temporal_compress_factor = temporal_compress_factor - self.max_enc_batch_size = max_enc_batch_size - self.max_dec_batch_size = max_dec_batch_size - - def register_mean_std(self, mean_std_fp: str) -> None: - latent_mean, latent_std = torch.load(mean_std_fp, map_location="cuda", weights_only=True) - latent_mean = latent_mean.view(self.latent_ch, -1)[:, : self.latent_chunk_duration] - latent_std = latent_std.view(self.latent_ch, -1)[:, : self.latent_chunk_duration] - - target_shape = [1, self.latent_ch, self.latent_chunk_duration, 1, 1] - - self.register_buffer( - "latent_mean", - latent_mean.to(self.dtype).reshape(*target_shape), - persistent=False, - ) - self.register_buffer( - "latent_std", - latent_std.to(self.dtype).reshape(*target_shape), - persistent=False, - ) - - def transform_encode_state_shape(self, state: torch.Tensor) -> torch.Tensor: - """ - Rearranges the input state tensor to the required shape for encoding video data. Mainly for chunk based encoding - """ - B, C, T, H, W = state.shape - assert ( - T % self.pixel_chunk_duration == 0 - ), f"Temporal dimension {T} is not divisible by chunk_length {self.pixel_chunk_duration}" - return rearrange(state, "b c (n t) h w -> (b n) c t h w", t=self.pixel_chunk_duration) - - def transform_decode_state_shape(self, latent: torch.Tensor) -> None: - B, _, T, _, _ = latent.shape - assert ( - T % self.latent_chunk_duration == 0 - ), f"Temporal dimension {T} is not divisible by chunk_length {self.latent_chunk_duration}" - return rearrange(latent, "b c (n t) h w -> (b n) c t h w", t=self.latent_chunk_duration) - - @torch.no_grad() - def encode(self, state: torch.Tensor) -> torch.Tensor: - if self._temporal_compress_factor == 1: - _, _, origin_T, _, _ = state.shape - state = rearrange(state, "b c t h w -> (b t) c 1 h w") - B, C, T, H, W = state.shape - state = self.transform_encode_state_shape(state) - # use max_enc_batch_size to avoid OOM - if state.shape[0] > self.max_enc_batch_size: - latent = [] - for i in range(0, state.shape[0], self.max_enc_batch_size): - latent.append(super().encode(state[i : i + self.max_enc_batch_size])) - latent = torch.cat(latent, dim=0) - else: - latent = super().encode(state) - - latent = rearrange(latent, "(b n) c t h w -> b c (n t) h w", b=B) - if self._temporal_compress_factor == 1: - latent = rearrange(latent, "(b t) c 1 h w -> b c t h w", t=origin_T) - return latent - - @torch.no_grad() - def decode(self, latent: torch.Tensor) -> torch.Tensor: - """ - Decodes a batch of latent representations into video frames by applying temporal chunking. Similar to encode, - it handles video data by processing smaller temporal chunks to reconstruct the original video dimensions. - - It can also decode single frame image data. - - Args: - latent (torch.Tensor): The latent space tensor containing encoded video data. - - Returns: - torch.Tensor: The decoded video tensor reconstructed from latent space. - """ - if self._temporal_compress_factor == 1: - _, _, origin_T, _, _ = latent.shape - latent = rearrange(latent, "b c t h w -> (b t) c 1 h w") - B, _, T, _, _ = latent.shape - latent = self.transform_decode_state_shape(latent) - # use max_enc_batch_size to avoid OOM - if latent.shape[0] > self.max_dec_batch_size: - state = [] - for i in range(0, latent.shape[0], self.max_dec_batch_size): - state.append(super().decode(latent[i : i + self.max_dec_batch_size])) - state = torch.cat(state, dim=0) - else: - state = super().decode(latent) - assert state.shape[2] == self.pixel_chunk_duration - state = rearrange(state, "(b n) c t h w -> b c (n t) h w", b=B) - if self._temporal_compress_factor == 1: - return rearrange(state, "(b t) c 1 h w -> b c t h w", t=origin_T) - return state - - @property - def pixel_chunk_duration(self) -> int: - return self._pixel_chunk_duration - - @property - def latent_chunk_duration(self) -> int: - # return self._latent_chunk_duration - assert (self.pixel_chunk_duration - 1) % self.temporal_compression_factor == 0, ( - f"Pixel chunk duration {self.pixel_chunk_duration} is not divisible by latent chunk duration " - f"{self.latent_chunk_duration}" - ) - return (self.pixel_chunk_duration - 1) // self.temporal_compression_factor + 1 +# class SDVAE(BaseVAE): +# def __init__(self, batch_size=16, count_std: bool = False, is_downsample: bool = True) -> None: +# super().__init__(channel=4, name="sd_vae") +# self.dtype = torch.bfloat16 +# self.register_buffer( +# "scale", +# torch.tensor([4.17, 4.62, 3.71, 3.28], dtype=self.dtype).reciprocal().reshape(1, -1, 1, 1), +# persistent=False, +# ) +# self.register_buffer( +# "bias", +# -1.0 * torch.tensor([5.81, 3.25, 0.12, -2.15], dtype=self.dtype).reshape(1, -1, 1, 1) * self.scale, +# persistent=False, +# ) +# self.batch_size = batch_size +# self.count_std = count_std +# self.is_downsample = is_downsample +# self.load_vae() +# self.reset_dtype() - @property - def temporal_compression_factor(self): - return self._temporal_compress_factor +# def reset_dtype(self, *args, **kwargs): +# del args, kwargs +# self.vae.to(self.dtype) - def get_latent_num_frames(self, num_pixel_frames: int) -> int: - if num_pixel_frames == 1: - return 1 - assert ( - num_pixel_frames % self.pixel_chunk_duration == 0 - ), f"Temporal dimension {num_pixel_frames} is not divisible by chunk_length {self.pixel_chunk_duration}" - return num_pixel_frames // self.pixel_chunk_duration * self.latent_chunk_duration +# @rank0_first +# def load_vae(self) -> None: +# os.environ["HF_HUB_DISABLE_SYMLINKS_WARNING"] = "1" +# os.environ["HF_HUB_DISABLE_PROGRESS_BARS"] = "1" +# import diffusers + +# vae_name = "stabilityai/sd-vae-ft-mse" +# try: +# vae = diffusers.models.AutoencoderKL.from_pretrained(vae_name, local_files_only=True) +# except: # noqa: E722 +# # Could not load the model from cache; try without local_files_only. +# vae = diffusers.models.AutoencoderKL.from_pretrained(vae_name) +# self.vae = vae.eval().requires_grad_(False) + +# @torch.no_grad() +# def encode(self, state: torch.Tensor) -> torch.Tensor: +# """ +# state : pixel range [-1, 1] +# """ +# if self.is_downsample: +# _h, _w = state.shape[-2:] +# state = F.interpolate(state, size=(_h // 2, _w // 2), mode="bilinear", align_corners=False) +# in_dtype = state.dtype +# state = state.to(self.dtype) +# state = (state + 1.0) / 2.0 +# latent_dist = self.vae.encode(state)["latent_dist"] +# mean, std = latent_dist.mean, latent_dist.std +# if self.count_std: +# latent = mean + torch.randn_like(mean) * std +# else: +# latent = mean +# latent = latent * self.scale +# latent = latent + self.bias +# return latent.to(in_dtype) + +# @torch.no_grad() +# def decode(self, latent: torch.Tensor) -> torch.Tensor: +# in_dtype = latent.dtype +# latent = latent.to(self.dtype) +# latent = latent - self.bias +# latent = latent / self.scale +# latent = torch.cat([self.vae.decode(batch)["sample"] for batch in latent.split(self.batch_size)]) +# if self.is_downsample: +# _h, _w = latent.shape[-2:] +# latent = F.interpolate(latent, size=(_h * 2, _w * 2), mode="bilinear", align_corners=False) +# return latent.to(in_dtype) * 2 - 1.0 + +# @property +# def spatial_compression_factor(self) -> int: +# return 8 - def get_pixel_num_frames(self, num_latent_frames: int) -> int: - if num_latent_frames == 1: - return 1 - assert ( - num_latent_frames % self.latent_chunk_duration == 0 - ), f"Temporal dimension {num_latent_frames} is not divisible by chunk_length {self.latent_chunk_duration}" - return num_latent_frames // self.latent_chunk_duration * self.pixel_chunk_duration class VideoJITTokenizer(BasePretrainedVideoTokenizer, JITVAE, VideoTokenizerInterface): @@ -735,4 +736,4 @@ def __init__(self, image_vae: Module, video_vae: Module, name: str, latent_ch: i ), f"video_vae should be an instance of VideoJITVAE, got {type(video_vae)}" # a hack to make the image_vae and video_vae share the same encoder and decoder self.image_vae.encoder = self.video_vae.encoder - self.image_vae.decoder = self.video_vae.decoder + self.image_vae.decoder = self.video_vae.decoder \ No newline at end of file diff --git a/cosmos_transfer1/utils/distributed.py b/cosmos_transfer1/utils/distributed.py index 68aac2f8..7250a825 100644 --- a/cosmos_transfer1/utils/distributed.py +++ b/cosmos_transfer1/utils/distributed.py @@ -20,6 +20,7 @@ import ctypes import functools import os +from contextlib import contextmanager from datetime import timedelta from typing import Any, Callable, Optional, TypeVar @@ -151,7 +152,36 @@ def wrapper(*args, **kwargs): # noqa: ANN202 return wrapper -def parallel_model_wrapper(config_ddp: DDPConfig, model: torch.nn.Module) -> torch.nn.Module | DDP: +class DistributedDataParallel(torch.nn.parallel.DistributedDataParallel): + """This extends torch.nn.parallel.DistributedDataParallel with .training_step(). + + This borrows the concept of `forward-redirection` from Pytorch lightning. It wraps an coreModel such that + model.training_step() would be executed when calling self.training_step(), while preserving the behavior of calling + model() for Pytorch modules. Internally, this is a double rerouting mechanism (training_step -> forward -> + training_step), allowing us to preserve the function names and signatures. + """ + + def __init__(self, model: torch.nn.Module, *args, **kwargs): + super().__init__(model, *args, **kwargs) + + def training_step(self, *args, **kwargs) -> Any: + # Cache the original model.forward() method. + original_forward = self.module.forward + + def wrapped_training_step(*_args, **_kwargs): # noqa: ANN202 + # Unpatch immediately before calling training_step() because itself may want to call the real forward. + self.module.forward = original_forward + # The actual .training_step(). + return self.module.training_step(*_args, **_kwargs) + + # Patch the original_module's forward so we can redirect the arguments back to the real method. + self.module.forward = wrapped_training_step + # Call self, which implicitly calls self.forward() --> model.forward(), which is now model.training_step(). + # Without calling self.forward() or model.forward() explciitly, implicit hooks are also executed. + return self(*args, **kwargs) + + +def parallel_model_wrapper(config_ddp: DDPConfig, model: torch.nn.Module) -> torch.nn.Module | DistributedDataParallel: """Wraps the model to enable data parallalism for training across multiple GPU devices. Args: @@ -171,7 +201,7 @@ def parallel_model_wrapper(config_ddp: DDPConfig, model: torch.nn.Module) -> tor log.info("parallel_state not initialized, treating all GPUs equally for DDP") ddp_group = None - model = DDP( + model = DistributedDataParallel( model, device_ids=[local_rank], output_device=local_rank, @@ -182,33 +212,37 @@ def parallel_model_wrapper(config_ddp: DDPConfig, model: torch.nn.Module) -> tor ) return model -class DistributedDataParallel(torch.nn.parallel.DistributedDataParallel): - """This extends torch.nn.parallel.DistributedDataParallel with .training_step(). - - This borrows the concept of `forward-redirection` from Pytorch lightning. It wraps an coreModel such that - model.training_step() would be executed when calling self.training_step(), while preserving the behavior of calling - model() for Pytorch modules. Internally, this is a double rerouting mechanism (training_step -> forward -> - training_step), allowing us to preserve the function names and signatures. - """ - def __init__(self, model: torch.nn.Module, *args, **kwargs): - super().__init__(model, *args, **kwargs) +@contextmanager +def ddp_sync_grad(model, enabled): + r""" + Context manager to enable/disable gradient synchronizations across DDP processes for DDP model. + Modified from: + https://pytorch.org/docs/stable/_modules/torch/nn/parallel/distributed.html#DistributedDataParallel.no_sync + Note that this is incompatible with static_graph=True and will be an no-op if static_graph=True. - def training_step(self, *args, **kwargs) -> Any: - # Cache the original model.forward() method. - original_forward = self.module.forward + Within this context, gradients will be accumulated on module + variables, which will later be synchronized in the first + forward-backward pass exiting the context. - def wrapped_training_step(*_args, **_kwargs): # noqa: ANN202 - # Unpatch immediately before calling training_step() because itself may want to call the real forward. - self.module.forward = original_forward - # The actual .training_step(). - return self.module.training_step(*_args, **_kwargs) - - # Patch the original_module's forward so we can redirect the arguments back to the real method. - self.module.forward = wrapped_training_step - # Call self, which implicitly calls self.forward() --> model.forward(), which is now model.training_step(). - # Without calling self.forward() or model.forward() explciitly, implicit hooks are also executed. - return self(*args, **kwargs) + .. warning:: + The forward pass should be included inside the context manager, or + else gradients will still be synchronized. + """ + assert isinstance(model, torch.nn.Module) + if isinstance(model, DistributedDataParallel): + old_require_backward_grad_sync = model.require_backward_grad_sync + if model.static_graph and model.require_backward_grad_sync != enabled: + if model.show_sync_grad_static_graph_warning: + log.warning("DDP static_graph=True is incompatible with sync_grad(). Performance will be reduced.") + model.show_sync_grad_static_graph_warning = False + else: + model.require_backward_grad_sync = enabled + try: + yield + finally: + if isinstance(model, DistributedDataParallel): + model.require_backward_grad_sync = old_require_backward_grad_sync def collate_batches(data_batches: list[dict[str, torch.Tensor]]) -> torch.Tensor | dict[str, torch.Tensor]: diff --git a/cosmos_transfer1/utils/trainer.py b/cosmos_transfer1/utils/trainer.py index cdb6af1f..bd1adc13 100644 --- a/cosmos_transfer1/utils/trainer.py +++ b/cosmos_transfer1/utils/trainer.py @@ -86,7 +86,7 @@ def __init__(self, config): if distributed.is_rank0(): # Print important environment variables and the effective config. log.info("Config:\n" + config.pretty_print(use_color=True)) - misc.print_environ_variables(["TORCH_HOME", "OUTPUT_ROOT"]) + misc.print_environ_variables(["OUTPUT_ROOT"]) # Set the random seed. If multi-GPU, different ranks are set with different seeds. misc.set_random_seed(seed=config.trainer.seed, by_rank=True) # Initialize cuDNN. diff --git a/examples/training_cosmos_transfer_7b.md b/examples/training_cosmos_transfer_7b.md index b54c77f5..9deb7eb0 100644 --- a/examples/training_cosmos_transfer_7b.md +++ b/examples/training_cosmos_transfer_7b.md @@ -191,8 +191,13 @@ Now we can start training! Run the following command to dry-run an example train ```bash export OUTPUT_ROOT=checkpoints # default value +# Training from scratch torchrun --nproc_per_node=1 -m cosmos_transfer1.diffusion.training.train --dryrun --config=cosmos_transfer1/diffusion/config/config_train.py -- experiment=CTRL_7Bv1pt3_lvg_tp_121frames_control_input_edge_block3_pretrain + +# Post-train from our provided checkpoint (need to first split checkpoint into TP checkpoints as instructed above) +torchrun --nproc_per_node=1 -m cosmos_transfer1.diffusion.training.train --dryrun --config=cosmos_transfer1/diffusion/config/config_train.py -- experiment=CTRL_7Bv1pt3_lvg_tp_121frames_control_input_edge_block3_posttrain ``` + Explanation of the command: - The trainer and the passed (master) config script will, in the background, load the detailed experiment configurations defined in `cosmos_transfer1/diffusion/config/training/experiment/ctrl_7b_tp_121frames.py`, and register the experiments configurations for all `hint_keys` (control modalities), covering both pretrain and post-train. We use [Hydra](https://hydra.cc/docs/intro/) for advanced configuration composition and overriding. @@ -201,30 +206,40 @@ Explanation of the command: - To customize your training, see `cosmos_transfer1/diffusion/config/training/experiment/ctrl_7b_tp_121frames.py` to understand how the detailed configs of the model, trainer, dataloader etc. are defined, and edit as needed. -- Removing the `--dryrun` will start a real training job. +- Removing the `--dryrun` and set `--nproc_per_node=8` will start a real training job on 8 GPUs: + + ```bash + torchrun --nproc_per_node=8 -m cosmos_transfer1.diffusion.training.train --config=cosmos_transfer1/diffusion/config/config_train.py -- experiment=CTRL_7Bv1pt3_lvg_tp_121frames_control_input_edge_block3_pretrain + ``` - Change the `experiment` value will decide which control modality model is trained, and whether it's pretrain or post-train. For example, replacing the experiment name in the command with `CTRL_7Bv1pt3_lvg_tp_121frames_control_input_depth_block3_posttrain` will post-train the DepthControl model from the downloaded checkpoint instead. - The checkpoints will be saved to `${OUTPUT_ROOT}/PROJECT/GROUP/NAME`. See the job config to understand how they are determined: -```python -# in cosmos_transfer1/diffusion/config/training/experiment/ctrl_7b_tp_121frames.py -config = LazyDict( - dict( - ... - job=dict( - project="cosmos_transfer1_pretrain", - group="CTRL_7Bv1_lvg", - name="CTRL_7Bv1pt3_lvg_tp_121frames_control_input_edge_block3_pretrain", - ), - ... + ```python + # In cosmos_transfer1/diffusion/config/training/experiment/ctrl_7b_tp_121frames.py + config = LazyDict( + dict( + ... + job=dict( + project="cosmos_transfer1_pretrain", + group="CTRL_7Bv1_lvg", + name="CTRL_7Bv1pt3_lvg_tp_121frames_control_input_edge_block3_pretrain", + ), + ... + ) ) -) -``` + ``` -During the training, the checkpoints will be saved in the below structure. -``` -checkpoints/cosmos_transfer1_pretrain/CTRL_7Bv1_lvg/CTRL_7Bv1pt3_lvg_tp_121frames_control_input_edge_block3_pretrain/checkpoints/ -├── iter_{NUMBER}_reg_model.pt -├── iter_{NUMBER}_ema_model.pt -``` + During the training, the checkpoints will be saved in the below structure. Since we use TensorParallel across 8 GPUs, 8 checkpoints will be saved each time. + + ``` + checkpoints/cosmos_transfer1_pretrain/CTRL_7Bv1_lvg/CTRL_7Bv1pt3_lvg_tp_121frames_control_input_edge_block3_pretrain/checkpoints/ + ├── iter_{NUMBER}.pt # "master" checkpoint, saving metadata only + ├── iter_{NUMBER}_model_mp_0.pt # real TP checkpoints + ├── iter_{NUMBER}_model_mp_1.pt + ├── ... + ├── iter_{NUMBER}_model_mp_7.pt + ``` + +- Since the `experiment` is uniquely associated with its checkpoint directory, rerunning the same training command after an unexpected interruption will automatically resume from the latest saved checkpoint. \ No newline at end of file diff --git a/scripts/convert_ckpt_fsdp_to_tp.py b/scripts/convert_ckpt_fsdp_to_tp.py index 3975a1dd..d0981dc2 100644 --- a/scripts/convert_ckpt_fsdp_to_tp.py +++ b/scripts/convert_ckpt_fsdp_to_tp.py @@ -78,7 +78,7 @@ def convert_fsdp_to_tp(path_in: str, path_out: str) -> None: Args: path_in: Path to input checkpoint (without _reg_model.pt suffix) - path_out: Path for output checkpoint (without _mp_X.pt suffix) + path_out: Path for output checkpoint (without _model_mp_X.pt suffix) tp_size: Number of tensor parallel partitions verbose: Whether to show progress bar @@ -105,7 +105,7 @@ def convert_fsdp_to_tp(path_in: str, path_out: str) -> None: easy_io.dump({'grad_scaler': {}, 'iteration': 0}, f"{path_out}.pt") for i in tqdm(range(TP_SIZE)): state_dict = {"model": state_dicts[i]} - easy_io.dump(state_dict, f"{path_out}_mp_{i}.pt") + easy_io.dump(state_dict, f"{path_out}_model_mp_{i}.pt") if __name__ == "__main__": @@ -116,9 +116,9 @@ def convert_fsdp_to_tp(path_in: str, path_out: str) -> None: python convert_ckpt_fsdp_to_tp.py checkpoints/nvidia/Cosmos-Transfer1-7B/vis_control.pt This will save the Tensor Parallel (TP) checkpoints as 8 files in the same directory: - checkpoints/nvidia/Cosmos-Transfer1-7B/vis_control_mp_0.pt + checkpoints/nvidia/Cosmos-Transfer1-7B/vis_control_model_mp_0.pt ... - checkpoints/nvidia/Cosmos-Transfer1-7B/vis_control_mp_7.pt + checkpoints/nvidia/Cosmos-Transfer1-7B/vis_control_model_mp_7.pt ''' if len(sys.argv) != 2: print("Usage: python convert_ckpt_fsdp_to_tp.py ") diff --git a/scripts/download_diffusion_example_data.py b/scripts/download_diffusion_example_data.py new file mode 100644 index 00000000..70bd4ba2 --- /dev/null +++ b/scripts/download_diffusion_example_data.py @@ -0,0 +1,121 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import json +import os + +import ffmpeg +from pytubefix import YouTube + +"""example command +CUDA_HOME=$CONDA_PREFIX PYTHONPATH=$(pwd) python scripts/download_diffusion_example_data.py --dataset_path datasets/hdvila --N_videos 128 --do_download --do_clip +""" + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser(description="Download example (hdvila) data for posttraining") + parser.add_argument("--dataset_path", type=str, default="datasets/hdvila", help="Root path to the dataset") + parser.add_argument("--N_videos", type=int, default=128, help="Number of videos to download") + parser.add_argument("--do_download", action="store_true", help="Download the videos") + parser.add_argument("--do_clip", action="store_true", help="Clip the videos") + return parser.parse_args() + + +def convert_time_to_seconds(time_str) -> int: + h, m, s = map(float, time_str.split(":")) + ms = int(time_str.split(".")[-1]) if "." in time_str else 0 + return int(h * 3600 + m * 60 + s) + ms / 1000 + + +def download_data(args) -> None: + urls_set = set() + download_count = 0 + + videos_orig_dir = os.path.join(args.dataset_path, "videos_orig") + os.makedirs(videos_orig_dir, exist_ok=True) + videos_dir = os.path.join(args.dataset_path, "videos") + os.makedirs(videos_dir, exist_ok=True) + metas_dir = os.path.join(args.dataset_path, "metas") + os.makedirs(metas_dir, exist_ok=True) + + hdvila_jsonl_path = os.path.join(args.dataset_path, "hdvila-100M.jsonl") + with open(hdvila_jsonl_path, "r") as fp: + for line in fp: + json_object = json.loads(line) + url = json_object["url"] + if url not in urls_set: # download videos with unique urls + yt = YouTube(json_object["url"]) + try: + # Download a video + yt.streams.get_highest_resolution().download( + output_path=videos_orig_dir, filename=json_object["video_id"] + ".mp4" + ) + download_count += 1 + urls_set.add(url) + print(f"Downloaded videos: {download_count}/{args.N_videos}") + + # Save metadata - caption and whole metadata + meta_txt_name = os.path.join(metas_dir, json_object["clip_id"].replace(".mp4", ".txt")) + with open(meta_txt_name, "w") as fp: + fp.write(json_object["caption"]) + meta_json_name = os.path.join(metas_dir, json_object["clip_id"].replace(".mp4", ".json")) + with open(meta_json_name, "w") as fp: + json.dump(json_object, fp) + except Exception as e: + print(e) + continue + + if len(urls_set) >= args.N_videos: + break + + +def clip_data(args) -> None: + videos_orig_dir = os.path.join(args.dataset_path, "videos_orig") + videos_dir = os.path.join(args.dataset_path, "videos") + os.makedirs(videos_dir, exist_ok=True) + metas_dir = os.path.join(args.dataset_path, "metas") + + metas_list = [ + os.path.join(metas_dir, filename) for filename in sorted(os.listdir(metas_dir)) if filename.endswith(".json") + ] + videos_orig_list = [ + os.path.join(videos_orig_dir, filename) + for filename in sorted(os.listdir(videos_orig_dir)) + if filename.endswith(".mp4") + ] + + for meta_filename, video_orig_filename in zip(metas_list, videos_orig_list): + with open(meta_filename, "r") as fp: + metadata = json.load(fp) + + # Convert time strings to seconds + start_time = convert_time_to_seconds(metadata["span_start"]) + end_time = convert_time_to_seconds(metadata["span_end"]) + # Clip the video + clip_name = os.path.join(videos_dir, metadata["clip_id"]) + ffmpeg.input(video_orig_filename, ss=start_time, t=end_time - start_time).output(clip_name).run() + + +def main(args) -> None: + if args.do_download: + download_data(args) + if args.do_clip: + clip_data(args) + + +if __name__ == "__main__": + args = parse_args() + main(args)