Skip to content

Commit f576c93

Browse files
authored
Merge pull request #712 from westfish/opensora_video_save
update opensora video save method
2 parents e2e9a79 + b126bd4 commit f576c93

3 files changed

Lines changed: 37 additions & 21 deletions

File tree

ppdiffusers/examples/Open-Sora/dataset/utils.py

Lines changed: 33 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from collections.abc import Sequence
1818

1919
import cv2
20+
import imageio
2021
import numpy as np
2122
import paddle
2223
import pandas as pd
@@ -247,40 +248,53 @@ def read_from_path(path, image_size, transform_name="center"):
247248
assert ext.lower() in IMG_EXTENSIONS, f"Unsupported file format: {ext}"
248249
return read_image_from_path(path, image_size=image_size, transform_name=transform_name)
249250

250-
251251
def save_sample(x, fps=8, save_path=None, normalize=True, value_range=(-1.0, 1.0)):
252252
"""
253+
Saves a video sample from a tensor without using OpenCV.
254+
253255
Args:
254-
x (Tensor): shape [C, T, H, W]
256+
x (Tensor): Tensor of shape [C, T, H, W].
257+
fps (int, optional): Frames per second for the saved video. Defaults to 8.
258+
save_path (str, optional): Path to save the video. If None, a default path is used.
259+
normalize (bool, optional): Whether to normalize the tensor values. Defaults to True.
260+
value_range (tuple, optional): Tuple specifying the (min, max) range for normalization. Defaults to (-1.0, 1.0).
261+
262+
Returns:
263+
str: The path where the video is saved.
255264
"""
256-
assert x.ndim == 4
265+
assert x.ndim == 4, f"Expected tensor with 4 dimensions [C, T, H, W], but got {x.ndim} dimensions."
266+
267+
if save_path is None:
268+
raise ValueError("save_path must be provided.")
257269

258270
save_path += ".mp4"
259-
if normalize:
260271

272+
if normalize:
261273
low, high = paddle.to_tensor(value_range, dtype="float32")
262-
x.clip_(min=low, max=high)
263-
x.subtract_(low).divide_(max(high - low, 1e-5))
274+
x = x.clip(min=low, max=high)
275+
x = (x - low) / paddle.maximum(high - low, paddle.to_tensor(1e-5))
264276

277+
# Scale to [0, 255] and convert to uint8
265278
video_data = (
266-
x.multiply(y=paddle.to_tensor(255.0, dtype="float32"))
267-
.add_(paddle.to_tensor(0.5, dtype="float32"))
268-
.clip_(0, 255)
279+
x.multiply(paddle.to_tensor(255.0, dtype="float32"))
280+
.add(paddle.to_tensor(0.5, dtype="float32")) # For rounding
281+
.clip(0, 255)
269282
)
270283
video_data = video_data.transpose([1, 2, 3, 0])
271-
video_data = video_data.numpy()
272-
video_data = video_data.astype(np.uint8)
284+
video_data = video_data.numpy().astype(np.uint8)
273285

274286
frames, height, width, channels = video_data.shape
275287

276-
fourcc = cv2.VideoWriter_fourcc(*"mp4v")
277-
out = cv2.VideoWriter(save_path, fourcc, fps, (width, height))
278-
279-
for i in range(frames):
280-
frame = cv2.cvtColor(video_data[i], cv2.COLOR_RGB2BGR)
281-
out.write(frame)
288+
# Initialize the video writer using imageio
289+
writer = imageio.get_writer(save_path, fps=fps, codec='libx264', format='mp4')
282290

283-
out.release()
291+
try:
292+
for i in range(frames):
293+
frame = video_data[i]
294+
# Ensure frame is in RGB format
295+
writer.append_data(frame)
296+
finally:
297+
writer.close()
284298

285299
print(f"Saved to {save_path}")
286-
return save_path
300+
return save_path

ppdiffusers/examples/Open-Sora/requirements.txt

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,4 +5,6 @@ pyarrow
55
pyav
66
tqdm
77
beartype
8-
pandarallel
8+
pandarallel
9+
imageio
10+
imageio-ffmpeg

ppdiffusers/examples/Open-Sora/schedulers/iddpm/gaussian_diffusion.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -184,7 +184,7 @@ def __init__(self, *, betas: paddle.Tensor, model_mean_type: str, model_var_type
184184
# Use float64 for accuracy.
185185
self.betas = betas
186186
assert len(self.betas.shape) == 1, "betas must be 1-D"
187-
assert (self.betas > 0).all() and (self.betas <= 1).all()
187+
assert (self.betas > 0).all() and (self.betas <= 1).all(), self.betas
188188

189189
self.num_timesteps = int(betas.shape[0])
190190

0 commit comments

Comments
 (0)