Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 2 additions & 7 deletions cosmos_transfer1/diffusion/inference/transfer.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,7 @@ def demo(cfg, control_inputs):
misc.set_random_seed(cfg.seed)

device_rank = 0
process_group = None
if cfg.num_gpus > 1:
from megatron.core import parallel_state

Expand Down Expand Up @@ -226,15 +227,9 @@ def demo(cfg, control_inputs):
canny_threshold=cfg.canny_threshold,
upsample_prompt=cfg.upsample_prompt,
offload_prompt_upsampler=cfg.offload_prompt_upsampler,
process_group=process_group,
)

if cfg.num_gpus > 1:
pipeline.model.model.net.enable_context_parallel(process_group)
pipeline.model.model.base_model.net.enable_context_parallel(process_group)
if hasattr(pipeline.model.model, "hint_encoders"):
pipeline.model.model.hint_encoders.net.enable_context_parallel(process_group)

# Handle multiple prompts if prompt file is provided
if cfg.batch_input_path:
log.info(f"Reading batch inputs from path: {cfg.batch_input_path}")
prompts = read_prompts_from_file(cfg.batch_input_path)
Expand Down
10 changes: 10 additions & 0 deletions cosmos_transfer1/diffusion/inference/world_generation_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,7 @@ def __init__(
canny_threshold: str = "medium",
upsample_prompt: bool = False,
offload_prompt_upsampler: bool = False,
process_group: torch.distributed.ProcessGroup | None = None,
):
"""Initialize diffusion world generation pipeline.

Expand All @@ -127,6 +128,7 @@ def __init__(
canny_threshold: Threshold for edge detection
upsample_prompt: Whether to upsample prompts using prompt upsampler model
offload_prompt_upsampler: Whether to offload prompt upsampler after use
process_group: Optional torch.distributed process group used to enable context-parallel (multi-GPU) inference
"""
self.num_input_frames = num_input_frames
self.control_inputs = control_inputs
Expand All @@ -138,6 +140,7 @@ def __init__(
self.prompt_upsampler = None
self.upsampler_hint_key = None
self.hint_details = None
self.process_group = process_group

self.model_name = MODEL_NAME_DICT[checkpoint_name]
self.model_class = MODEL_CLASS_DICT[checkpoint_name]
Expand Down Expand Up @@ -296,6 +299,13 @@ def _load_network(self):
) # , weights_only=True)
non_strict_load_model(self.model.model, net_state_dict)

if self.process_group is not None:
self.model.model.net.enable_context_parallel(self.process_group)
self.model.model.base_model.net.enable_context_parallel(self.process_group)
if hasattr(self.model.model, "hint_encoders"):
self.model.model.hint_encoders.net.enable_context_parallel(self.process_group)


def _load_tokenizer(self):
    """Load the Cosmos tokenizer weights from the checkpoint directory into self.model.

    Delegates to the project helper ``load_tokenizer_model``; the checkpoint path is
    built as ``{checkpoint_dir}/{COSMOS_TOKENIZER_CHECKPOINT}``.
    """
    # NOTE(review): assumes COSMOS_TOKENIZER_CHECKPOINT is a path relative to
    # self.checkpoint_dir — confirm against the constant's definition.
    load_tokenizer_model(self.model, f"{self.checkpoint_dir}/{COSMOS_TOKENIZER_CHECKPOINT}")

Expand Down