Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion helpers/caching/vae.py
Original file line number Diff line number Diff line change
Expand Up @@ -677,7 +677,7 @@ def encode_images(self, images, filepaths, load_from_cache=True):
)

# For Wan, get the raw parameters (32 channels)
if StateTracker.get_model_family() in ["wan", "cosmos2image"]:
if StateTracker.get_model_family() in ["wan"]:
if hasattr(latents_uncached, "latent_dist"):
# This is 32 channels (mu + logvar)
latents_uncached = latents_uncached.latent_dist.parameters
Expand Down
168 changes: 89 additions & 79 deletions helpers/models/cosmos/model.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import torch, os, logging
import torch.nn.functional as F
from helpers.models.cosmos.pipeline import Cosmos2TextToImagePipeline
from helpers.models.cosmos.transformer import CosmosTransformer3DModel
from helpers.models.common import (
Expand All @@ -24,7 +25,7 @@

class Cosmos2Image(VideoModelFoundation):
NAME = "Cosmos (T2I)"
PREDICTION_TYPE = PredictionTypes.FLOW_MATCHING
PREDICTION_TYPE = PredictionTypes.SAMPLE
MODEL_TYPE = ModelTypes.TRANSFORMER
AUTOENCODER_CLASS = AutoencoderKLWan
LATENT_CHANNEL_COUNT = 16
Expand Down Expand Up @@ -149,92 +150,101 @@ def pre_vae_encode_transform_sample(self, sample):

return sample

def model_predict(self, prepared_batch):
def prepare_batch(self, batch: dict, state: dict) -> dict:
"""
1. Move tensors to the accelerator device / dtype.
2. Draw σ from the log-uniform EDM schedule.
3. Add additive Gaussian noise xₜ = x₀ + σ ε.
4. Store `sigmas` (broadcast shape B×1×1×1×1) and `noisy_latents`.
Everything else (prompt embeds, masks, etc.) follows the base
implementation.
"""
Perform model prediction for training.
if not batch:
return batch

Args:
prepared_batch: Dictionary containing batch data
# ---------- move prompt embeds & latents to device --------------
target_kwargs = {"device": self.accelerator.device,
"dtype": self.config.weight_dtype}

Returns:
Dictionary containing model prediction
"""
if prepared_batch["noisy_latents"].shape[1] != 32:
raise ValueError(
f"Cosmos T2I requires a latent size of 32 channels. "
f"Batch received: {prepared_batch['noisy_latents'].shape}"
)
# we have to split the mu and logvar channels on the noisy latents
if prepared_batch["noisy_latents"].shape[1] != 16:
# just slice the first 16 channels and discard the rest
prepared_batch["noisy_latents"] = prepared_batch["noisy_latents"].narrow(
1, 0, 16
)
# slice also the target latents
prepared_batch["latents"] = prepared_batch["latents"].narrow(1, 0, 16)
# and the noise
prepared_batch["noise"] = prepared_batch["noise"].narrow(1, 0, 16)

# For T2I, we use single frame (num_frames=1)
batch_size, channels, num_frames, height, width = prepared_batch[
"noisy_latents"
].shape

# Create padding mask
padding_mask = torch.zeros(
1,
1,
height,
width,
device=prepared_batch["noisy_latents"].device,
dtype=prepared_batch["noisy_latents"].dtype,
)
if batch.get("prompt_embeds") is not None:
batch["encoder_hidden_states"] = batch["prompt_embeds"].to(**target_kwargs)

# Prepare timesteps - Cosmos uses a different timestep format
timesteps = prepared_batch["timesteps"]
current_sigma = timesteps # Assuming timesteps are sigmas
current_t = current_sigma / (current_sigma + 1)
timestep = current_t.to(dtype=prepared_batch["noisy_latents"].dtype)

# Model forward pass
model_pred = self.model(
hidden_states=prepared_batch["noisy_latents"].to(
device=self.accelerator.device,
dtype=self.config.weight_dtype,
),
timestep=timestep,
encoder_hidden_states=prepared_batch["encoder_hidden_states"].to(
device=self.accelerator.device,
dtype=self.config.weight_dtype,
),
padding_mask=padding_mask,
return_dict=False,
)[0]
# return the split mu and logvar channels
model_pred = model_pred[:, :16, :, :, :] # Keep only the first
return {
"model_prediction": model_pred,
}
latents = batch["latent_batch"].to(**target_kwargs) # clean x0
batch["latents"] = latents

def prepare_flow_matching_params(self, batch_size: int, device: torch.device):
"""
Prepare flow matching specific parameters.
# ---------- plain Gaussian noise ε ------------------------------
noise = torch.randn_like(latents)
batch["noise"] = noise
batch["input_noise"] = noise # no extra perturbation

Args:
batch_size: Current batch size
device: Device to create tensors on
# ---------- draw σ and form x_t ---------------------------------
bsz = latents.size(0)
sigmas = self.prepare_edm_sigmas(bsz, self.accelerator.device)["sigmas"] # (B,)
sigmas_exp = sigmas.view(-1, 1, 1, 1, 1) # B×1×1×1×1

Returns:
Dictionary of flow matching parameters
"""
# Sample sigmas according to Cosmos schedule
sigmas = torch.rand(batch_size, device=device)
sigmas = self.sigma_min + (self.sigma_max - self.sigma_min) * sigmas
batch["sigmas"] = sigmas_exp
batch["timesteps"] = sigmas # unused but kept for API
batch["noisy_latents"] = latents + sigmas_exp * noise # x_t

return {
"sigmas": sigmas,
"sigma_data": self.sigma_data,
}
# ---------- any ControlNet / mask specific tweaks ---------------
batch = self.prepare_batch_conditions(batch=batch, state=state)
return batch

def model_predict(self, prepared_batch):
    """
    Run the transformer on a noised batch and return a clean-latent estimate.

    Reads `noisy_latents`, `sigmas` and `encoder_hidden_states` from
    `prepared_batch`. The noisy latents are scaled by 1 / (sigma + 1) before
    being passed to `self.model`, which is conditioned on sigma / (sigma + 1).
    The network output is then recombined with the noisy latents as
    x_t / (sigma + 1) - sigma / (sigma + 1) * output.

    Args:
        prepared_batch: Mapping holding the tensors listed above.

    Returns:
        Dictionary with the single key `model_prediction`.
    """
    noisy = prepared_batch["noisy_latents"]
    sigma = prepared_batch["sigmas"].view(-1, 1, 1, 1, 1)  # B×1×1×1×1
    # Unpacking five sizes doubles as a rank check on the noisy latents.
    batch_size, _, _, height, width = noisy.shape
    target_device = self.accelerator.device
    weight_dtype = self.config.weight_dtype

    # The same factor scales both the network input and the skip connection.
    scale_in = 1.0 / (sigma + 1.0)
    scale_out = -sigma * scale_in

    scaled_latents = noisy * scale_in
    model_kwargs = {
        "hidden_states": scaled_latents.to(weight_dtype),
        "timestep": (sigma / (sigma + 1)).view(batch_size).to(dtype=weight_dtype),
        "encoder_hidden_states": prepared_batch["encoder_hidden_states"].to(
            weight_dtype
        ),
        # All-zero mask; it keeps the dtype of the scaled latents, not the
        # weight dtype.
        "padding_mask": torch.zeros(
            batch_size,
            1,
            height,
            width,
            device=target_device,
            dtype=scaled_latents.dtype,
        ),
        "return_dict": False,
    }
    raw_output = self.model(**model_kwargs)[0]  # transformer output

    # Recombine in float32 to obtain the x0 estimate.
    denoised = scale_in * noisy + scale_out * raw_output.float()
    return {"model_prediction": denoised}

def loss(self, prepared_batch, model_output, apply_conditioning_mask=True):
    """
    Sigma-weighted mean squared error between predicted and clean latents.

    Each element's squared error is multiplied by
    (sigma^2 + sigma_data^2) / (sigma * sigma_data)^2 before averaging.

    Args:
        prepared_batch: Mapping providing `latents` and `sigmas`, and
            optionally `conditioning_type` with `conditioning_pixel_values`.
        model_output: Mapping providing `model_prediction`.
        apply_conditioning_mask: When true, a `mask` or `segmentation`
            conditioning type rescales the per-element loss.

    Returns:
        Scalar tensor: the mean of the weighted (and possibly masked) loss.
    """
    target = prepared_batch["latents"].float()
    prediction = model_output["model_prediction"].float()
    sigma = prepared_batch["sigmas"]

    weight = (sigma ** 2 + self.sigma_data ** 2) / (sigma * self.sigma_data) ** 2
    # Append trailing singleton axes until the weight broadcasts over the target.
    for _ in range(target.ndim - weight.ndim):
        weight = weight.unsqueeze(-1)

    weighted = F.mse_loss(prediction, target, reduction="none") * weight

    if not apply_conditioning_mask:
        return weighted.mean()

    conditioning_type = prepared_batch.get("conditioning_type")
    if conditioning_type == "mask":
        # Only the first channel of the conditioning input is used.
        mask = prepared_batch["conditioning_pixel_values"][:, :1]
        mask = torch.nn.functional.interpolate(
            mask, size=weighted.shape[2:], mode="area"
        )
        weighted.mul_(mask / 2 + 0.5)
    elif conditioning_type == "segmentation":
        # Collapse the channel axis (sum / 3), then binarise.
        mask = prepared_batch["conditioning_pixel_values"]
        mask = torch.sum(mask, dim=1, keepdim=True) / 3
        mask = torch.nn.functional.interpolate(
            mask, size=weighted.shape[2:], mode="area"
        )
        weighted.mul_(((mask / 2 + 0.5) > 0).to(weighted.dtype))

    return weighted.mean()

def prepare_edm_sigmas(self, bsz: int, device: torch.device) -> dict:
    """
    Draw one noise level per sample, log-uniformly between the sigma bounds.

    A uniform sample u in [0, 1) is mapped to
    10 ** (log10(sigma_min) + (log10(sigma_max) - log10(sigma_min)) * u).

    The result is wrapped in a dictionary because `prepare_batch` reads it
    as `self.prepare_edm_sigmas(...)["sigmas"]`; returning the bare tensor
    makes that subscript raise.

    Args:
        bsz: Number of sigmas to draw.
        device: Device on which the sigmas are created.

    Returns:
        Dictionary with key `sigmas` holding a tensor of shape (bsz,).
    """
    # Build both bounds on the target device so no cross-device math occurs.
    # NOTE(review): assumes self.sigma_min and self.sigma_max are positive
    # Python numbers - confirm where they are defined.
    bounds = torch.tensor([self.sigma_min, self.sigma_max], device=device)
    log_min, log_max = torch.log10(bounds)
    u = torch.rand(bsz, device=device)
    sigmas = 10.0 ** (log_min + (log_max - log_min) * u)
    return {"sigmas": sigmas}

def check_user_config(self):
"""
Expand Down