Bug in temporal_sample (video2world.py): Error with 1-frame input & failure to generate 12-frame output; Fix via frame repetition #173
Description
When running the test code for Video2World Post-training for Action-conditioned Video Prediction with the following command:

```bash
python examples/video2world_action.py --model_size 2B --dit_path "/home/CONNECT/yfang870/cosmos/cosmos-predict2/checkpoints/nvidia/Cosmos-Predict2-2B-Sample-Action-Conditioned/model-480p-4fps.pt" --input_video datasets/bridge/videos/test/13/rgb.mp4 --input_annotation datasets/bridge/annotation/test/13.json --num_conditional_frames 1 --save_path output/generated_video.mp4 --guidance 0 --seed 0 --disable_guardrail --disable_prompt_refiner
```
There is a function in cosmos_predict2/pipelines/video2world.py:

```python
def temporal_sample(video: torch.Tensor, expected_length: int) -> torch.Tensor:
    # sample consecutive video frames to match expected_length
    original_length = video.shape[2]
    if original_length != expected_length:
        # video in [B C T H W] format
        start_frame = np.random.randint(0, original_length - expected_length)
        end_frame = start_frame + expected_length
        video = video[:, :, start_frame:end_frame, :, :]
    return video

expected_length = self.tokenizer.get_pixel_num_frames(self.config.state_t)
original_length = data_batch[input_key].shape[2]
if original_length != expected_length:
    data_batch[input_key] = temporal_sample(data_batch[input_key], expected_length)
```

The purpose of this function is to sample a consecutive clip of frames when the input video is longer than the expected generation length.
However, during testing, when the input video has only 1 frame, the code crashes at start_frame = np.random.randint(0, original_length - expected_length): with original_length = 1 and expected_length = 12, the upper bound is negative, so np.random.randint raises ValueError: low >= high. If we instead comment out the call to temporal_sample, the returned video keeps its single frame, so the noise fed to the model is also 1 frame long; the output video then contains only 1 frame and cannot grow to the 12-frame video implied by the action length.
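A minimal sketch that reproduces the crash (the tensor shape here is illustrative, not the model's actual resolution):

```python
import numpy as np
import torch

video = torch.zeros(1, 3, 1, 480, 832)  # hypothetical 1-frame input, [B C T H W]
expected_length = 12

# 1 - 12 == -11, so the upper bound is below the lower bound and
# np.random.randint raises "ValueError: low >= high"
start_frame = np.random.randint(0, video.shape[2] - expected_length)
```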
Therefore, we can repeat (or pad) frames so that the video length matches the expected length, ensuring the generated video has a consistent length.
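For the padding variant, a minimal sketch (pad_to_length is a hypothetical helper, not part of the repository; it replicates the last frame along T):

```python
import torch
import torch.nn.functional as F

def pad_to_length(video: torch.Tensor, expected_length: int) -> torch.Tensor:
    # video in [B C T H W] format; replicate the final frame until T == expected_length
    pad_frames = expected_length - video.shape[2]
    if pad_frames <= 0:
        return video
    # F.pad's tuple pads the last three dims: (W_left, W_right, H_left, H_right, T_front, T_back)
    return F.pad(video, (0, 0, 0, 0, 0, pad_frames), mode="replicate")
```

For a single-frame input, this is equivalent to the repeat used in the fix below.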
The final code modification changes temporal_sample to:

```python
def temporal_sample(video: torch.Tensor, expected_length: int) -> torch.Tensor:
    # sample or repeat video frames to match expected_length; video in [B C T H W] format
    original_length = video.shape[2]
    if original_length == expected_length:
        return video
    elif original_length == 1:
        # single-frame input: repeat it along T so the model's noise has the expected length
        video = video.repeat(1, 1, expected_length, 1, 1)
        return video
    else:
        # assumes original_length > expected_length; the +1 makes the last valid start inclusive
        start_frame = np.random.randint(0, original_length - expected_length + 1)
        end_frame = start_frame + expected_length
        video = video[:, :, start_frame:end_frame, :, :]
        return video
```
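A quick sanity check of the patched function across its three branches (shapes are illustrative):

```python
import torch

# assumes the patched temporal_sample above is in scope
for t in (1, 12, 30):
    video = torch.zeros(1, 3, t, 64, 64)
    out = temporal_sample(video, expected_length=12)
    assert out.shape == (1, 3, 12, 64, 64)
```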