Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
35 commits
Select commit Hold shift + click to select a range
2efa849
ClippingSubsampler rewrite and bug fixes
MattUnderscoreZhang Jan 18, 2024
a5c9649
More refactoring of ClippingSubsampler, plus a fix to _get_clip_inter…
MattUnderscoreZhang Jan 18, 2024
2cb5854
Finished refactoring ClippingSubsampler
MattUnderscoreZhang Jan 18, 2024
6106f62
Merge branch 'clipping_subsampler_rewrite' into all_fixes
MattUnderscoreZhang Jan 18, 2024
5d03b72
Final code changes
MattUnderscoreZhang Jan 19, 2024
47c7d64
Added docstrings
MattUnderscoreZhang Jan 19, 2024
5aa84d4
Passed tests and linting
MattUnderscoreZhang Jan 19, 2024
140e1ab
Made type annotations consistent with Python 3.8
MattUnderscoreZhang Jan 19, 2024
077ca27
More annotation fixes
MattUnderscoreZhang Jan 19, 2024
32fa4ea
The Python 3.8 annotation needs a lot of hand-holding, it seems
MattUnderscoreZhang Jan 19, 2024
5a8957f
Pylint has to cut it out, I swear to God
MattUnderscoreZhang Jan 19, 2024
f0f0168
No real change, just relaunching unit tests which failed due to connec…
MattUnderscoreZhang Jan 19, 2024
f5d7c85
Merge branch 'main' into clipping_subsampler_refactor
iejMac Jan 19, 2024
388f51a
Merge branch 'main' into clipping_subsampler_refactor
rom1504 Jan 21, 2024
5101379
Merge remote-tracking branch 'origin/main' into clipping_subsampler_r…
MattUnderscoreZhang Jan 22, 2024
1df88dd
Linting issue
MattUnderscoreZhang Jan 22, 2024
226fba3
Another linting issue
MattUnderscoreZhang Jan 22, 2024
8ed5074
Separated per-shard code from code that should only be executed once
MattUnderscoreZhang Jan 24, 2024
e862eaa
Pulled ShardStatus parameters into their own data type
MattUnderscoreZhang Jan 24, 2024
d158106
Cleaned up shard processing error handling
MattUnderscoreZhang Jan 24, 2024
5cd53a9
Cleaned up code
MattUnderscoreZhang Jan 24, 2024
ffe0e71
Bug fixes
MattUnderscoreZhang Jan 24, 2024
2c7daf8
Formatting
MattUnderscoreZhang Jan 24, 2024
ac5a35b
Fixed linting issues
MattUnderscoreZhang Jan 24, 2024
5222f39
Fixing more damn linting
MattUnderscoreZhang Jan 24, 2024
6dc8991
Added a missing docstring
MattUnderscoreZhang Jan 24, 2024
6cbb43f
Unified SubsetWorker and DownloadWorker code
MattUnderscoreZhang Jan 24, 2024
d5f3b19
Bug fixes
MattUnderscoreZhang Jan 24, 2024
efceb33
Merge branch 'main' into download_worker_refactoring
MattUnderscoreZhang Jan 24, 2024
f33ed6c
Linting
MattUnderscoreZhang Jan 24, 2024
fb89ced
Linting again
MattUnderscoreZhang Jan 24, 2024
fca3332
Forgot a docstring
MattUnderscoreZhang Jan 24, 2024
6ba72af
Merge branch 'main' into download_worker_refactoring
MattUnderscoreZhang Jan 26, 2024
afc432b
Removed unnecessary manual thread handling
MattUnderscoreZhang Jan 26, 2024
c7b38af
Removed unused import
MattUnderscoreZhang Jan 26, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions video2dataset/subsamplers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,3 +12,5 @@
from .optical_flow_subsampler import OpticalFlowSubsampler
from .whisper_subsampler import WhisperSubsampler
from .caption_subsampler import CaptionSubsampler

from .subsampler import Subsampler
4 changes: 4 additions & 0 deletions video2dataset/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,3 +10,7 @@ class EncodeFormats(TypedDict, total=False):
class Streams(TypedDict, total=False):
video: List[bytes]
audio: List[bytes]


# TODO: make more structured
Metadata = dict
304 changes: 96 additions & 208 deletions video2dataset/workers/download_worker.py
Original file line number Diff line number Diff line change
@@ -1,29 +1,15 @@
"""the downloader module handles the downloading"""

import fsspec
import math
import time
from multiprocessing.pool import ThreadPool
import pyarrow as pa
import time
import traceback

import fsspec

from multiprocessing.pool import ThreadPool
from threading import Semaphore
from typing import List, Any
import numpy as np
from typing import cast

from video2dataset.data_reader import VideoDataReader
from video2dataset.logger import CappedCounter
from video2dataset.logger import write_stats
from video2dataset.subsamplers import (
ClippingSubsampler,
CutDetectionSubsampler,
FrameSubsampler,
FFProbeSubsampler,
NoOpSubsampler,
ResolutionSubsampler,
AudioRateSubsampler,
)
from video2dataset.workers.worker import ShardStatus, Streams, get_subsamplers, process_sample


def compute_key(key, shard_id, oom_sample_per_shard, oom_shard_count):
Expand Down Expand Up @@ -52,252 +38,154 @@ def __init__(
self.save_caption = save_caption
self.output_folder = output_folder
self.column_list = column_list
self.encode_formats = encode_formats
self.input_encode_formats = encode_formats
self.config = config

self.data_reader = VideoDataReader(encode_formats, tmp_dir, config["reading"])

self.clipping_subsampler = ClippingSubsampler(
5, # oom_clip_count
self.url_indice = self.column_list.index("url")
self.caption_indice = self.column_list.index("caption") if "caption" in self.column_list else None
self.oom_sample_per_shard = math.ceil(math.log10(self.config["storage"]["number_sample_per_shard"]))
self.subsamplers, self.output_encode_formats = get_subsamplers(
config,
encode_formats,
**self.config["subsampling"].get("ClippingSubsampler", {"args": {}})["args"],
do_clipping=("clips" in self.column_list),
)
need_keyframes = self.clipping_subsampler.precision == "keyframe_adjusted"

self.ffprobe_subsampler = None
if "FFProbeSubsampler" in self.config["subsampling"] or need_keyframes:
self.ffprobe_subsampler = FFProbeSubsampler(
**self.config["subsampling"].get("FFProbeSubsampler", {"args": {}})["args"]
)
self.ffprobe_subsampler.extract_keyframes |= need_keyframes

self.cut_detector = None
self.cuts_are_clips = False
if "CutDetectionSubsampler" in self.config["subsampling"]:
if "args" in self.config["subsampling"]["CutDetectionSubsampler"]:
self.cut_detector = CutDetectionSubsampler(
**self.config["subsampling"]["CutDetectionSubsampler"]["args"]
)
self.cuts_are_clips = self.config["subsampling"]["CutDetectionSubsampler"].get("cuts_are_clips", False)

self.noop_subsampler = NoOpSubsampler()

video_subsamplers: List[Any] = []
if "ResolutionSubsampler" in self.config["subsampling"]:
video_subsamplers.append(ResolutionSubsampler(**self.config["subsampling"]["ResolutionSubsampler"]["args"]))
if "FrameSubsampler" in self.config["subsampling"]:
video_subsamplers.append(FrameSubsampler(**self.config["subsampling"]["FrameSubsampler"]["args"]))

audio_subsamplers: List[Any] = []
if "AudioRateSubsampler" in self.config["subsampling"]:
audio_subsamplers.append(AudioRateSubsampler(**self.config["subsampling"]["AudioRateSubsampler"]["args"]))

self.subsamplers = {"video": video_subsamplers, "audio": audio_subsamplers}

def __call__(
self,
row,
):
try:
self.download_shard(row)
shard_file, shard_id = row
self.process_shard(shard_file, shard_id)
return (True, row)
except Exception as err: # pylint: disable=broad-except
traceback.print_exc()
print(f"shard {row[0]} failed with error {err}")
return (False, row)

def download_shard(
def get_shard_processors(
self,
row,
shard_file: str,
shard_id: int,
):
"""Function to start a video download in one process"""

# shard_id, shard_file = row
shard_file, shard_id = row
start_time = time.time()
"""Get objects for loading and writing data"""

fs, shard_path = fsspec.core.url_to_fs(shard_file)
print(shard_path)
with fs.open(shard_path, "rb") as f:
df = pa.ipc.open_file(f).read_all()
schema = df.schema
schema = df.schema
schema = (
schema.append(pa.field("key", pa.string()))
.append(pa.field("status", pa.string()))
.append(pa.field("error_message", pa.string()))
)

shard_sample_writer = self.sample_writer_class(
shard_id,
self.output_folder,
self.save_caption,
self.config["storage"]["oom_shard_count"],
schema,
self.output_encode_formats,
)
pydict = df.select(self.column_list).to_pydict()
shard_to_dl = list(enumerate(zip(*(pydict[col] for col in self.column_list))))
del pydict
del df

status_dict = CappedCounter()

count = len(shard_to_dl)
successes = 0
failed = {
"failed_to_download": 0,
"failed_to_subsample": 0,
}
bytes_downloaded = 0
url_indice = self.column_list.index("url")
caption_indice = self.column_list.index("caption") if "caption" in self.column_list else None
key_url_list = [(key, x[url_indice]) for key, x in shard_to_dl]
def rm_shard_path():
fs.rm(shard_path)

semaphore = Semaphore(self.config["distribution"]["thread_count"])
return shard_sample_writer, shard_to_dl, rm_shard_path

def data_generator():
for e in key_url_list:
semaphore.acquire() # pylint: disable=(consider-using-with)
yield e
def process_shard(
self,
shard_file: str,
shard_id: int,
):
"""Function to start a video download in one process"""

loader = data_generator()
start_time = time.time()
shard_sample_writer, shard_to_dl, rm_shard_path = self.get_shard_processors(shard_file, shard_id)
shard_status = ShardStatus(count=len(shard_to_dl))

# The subsamplers might change the output format, so we need to update the writer
writer_encode_formats = self.encode_formats.copy()
if self.subsamplers["audio"]:
writer_encode_formats["audio"] = self.subsamplers["audio"][0].encode_formats["audio"]
if self.subsamplers["video"]:
writer_encode_formats["video"] = self.subsamplers["video"][0].encode_formats["video"]
def data_generator():
for key_and_url in [(key, x[self.url_indice]) for key, x in shard_to_dl]:
yield key_and_url

# give schema to writer
sample_writer = self.sample_writer_class(
shard_id,
self.output_folder,
self.save_caption,
self.config["storage"]["oom_shard_count"],
schema,
writer_encode_formats,
)
oom_sample_per_shard = math.ceil(math.log10(self.config["storage"]["number_sample_per_shard"]))
data_reader_call_param_generator = data_generator()

with ThreadPool(self.config["distribution"]["thread_count"]) as thread_pool:
for key, streams, yt_meta_dict, error_message in thread_pool.imap_unordered(
for key, streams, yt_meta_dict, shard_status.error_message in thread_pool.imap_unordered(
self.data_reader, # pylint: disable=(unnecessary-lambda)
loader,
data_reader_call_param_generator,
):
try:
_, sample_data = shard_to_dl[key]
str_key = compute_key(
key, shard_id, oom_sample_per_shard, self.config["storage"]["oom_shard_count"]
key, shard_id, self.oom_sample_per_shard, self.config["storage"]["oom_shard_count"]
)
meta = {
caption = sample_data[self.caption_indice] if self.caption_indice is not None else None
metadata = {
**{self.column_list[i]: sample_data[i] for i in range(len(self.column_list))},
"key": str_key,
"status": None,
"error_message": error_message,
"error_message": shard_status.error_message,
"yt_meta_dict": yt_meta_dict,
}

if error_message is not None:
print(error_message)
if "[youtube]" in error_message: # video-specific error, remove videoID
error_message = "ERROR: [youtube]:" + error_message.split(":")[-1]
raise ValueError("failed_to_download")

for stream in streams.values():
bytes_downloaded += len(stream)
for mod in streams:
streams[mod] = [streams[mod]]

if self.ffprobe_subsampler is not None:
streams, meta, error_message = self.ffprobe_subsampler(streams, meta)
if error_message is not None:
raise ValueError("failed_to_subsample")

if self.config["storage"]["captions_are_subtitles"]: # create clips
# all langs have same start and end times
subtitles = meta["yt_meta_dict"]["subtitles"][list(meta["yt_meta_dict"]["subtitles"].keys())[0]]
meta["clips"] = [[line_dict["start"], line_dict["end"]] for line_dict in subtitles]
elif self.cut_detector is not None: # apply cut detection to get clips
streams, cuts, error_message = self.cut_detector(streams)

if error_message is not None:
raise ValueError("failed_to_subsample")

meta["cuts"] = cuts

if self.cuts_are_clips:
cuts = meta["cuts"]["cuts_original_fps"]
native_fps = meta["cuts"]["original_fps"]
meta["clips"] = (np.array(cuts) / native_fps).tolist()

# 1 video -> many videos (either clipping or noop which does identity broadcasting)
broadcast_subsampler = (
self.clipping_subsampler
if (
"clips" in self.column_list
or self.config["storage"]["captions_are_subtitles"]
or self.cuts_are_clips
)
else self.noop_subsampler
)
subsampled_streams, metas, error_message = broadcast_subsampler(streams, meta)

for modality in subsampled_streams:
for modality_subsampler in self.subsamplers[modality]:
subsampled_streams, metas, error_message = modality_subsampler(subsampled_streams, metas)

if error_message is not None:
meta["clips"] = []
raise ValueError("failed_to_subsample")

successes += 1
status = "success"
status_dict.increment(status)
subsampled_streams_list = [
dict(zip(subsampled_streams, s)) for s in zip(*subsampled_streams.values())
]
for subsampled_streams, meta in zip(subsampled_streams_list, metas):
meta["status"] = status

text_caption = sample_data[caption_indice] if caption_indice is not None else None
if self.config["storage"]["captions_are_subtitles"]:
text_caption = meta.get("clip_subtitles")[0]["lines"]

sample_writer.write(
subsampled_streams,
meta["key"],
text_caption,
meta,
)
except Exception as err: # pylint: disable=broad-except
status = str(err)
if status.startswith("failed_to_"):
failed[status] += 1
status_dict.increment(error_message)
meta["status"] = status
meta["error_message"] = error_message
sample_writer.write(
{},
str_key,
sample_data[caption_indice] if caption_indice is not None else None,
meta,
)
semaphore.release()
else:
traceback.print_exc()
print(f"Sample {key} failed to download: {err}")
traceback.print_exc()
print(f"Sample {key} failed to download: {err}")
return

semaphore.release()

sample_writer.close()
thread_pool.terminate()
thread_pool.join()
del thread_pool
try:
if shard_status.error_message is not None:
print(shard_status.error_message)
if "[youtube]" in shard_status.error_message: # video-specific error, remove videoID
shard_status.error_message = "ERROR: [youtube]:" + shard_status.error_message.split(":")[-1]
raise ValueError
except Exception: # pylint: disable=broad-except
shard_status.failed["failed_to_download"] += 1
shard_status.status_dict.increment(shard_status.error_message)
metadata["status"] = "failed_to_download"
metadata["error_message"] = shard_status.error_message
shard_sample_writer.write(
{},
str_key,
sample_data[self.caption_indice] if self.caption_indice is not None else None,
metadata,
)
return

for stream in streams.values():
shard_status.bytes_downloaded += len(stream)
for modality in streams:
streams[modality] = [streams[modality]]

process_sample(
subsamplers=self.subsamplers,
shard_status=shard_status,
streams=cast(Streams, streams),
key=str_key,
caption=cast(str, caption),
metadata=metadata,
captions_are_subtitles=self.config["storage"]["captions_are_subtitles"],
shard_sample_writer=shard_sample_writer,
)

shard_sample_writer.close()
rm_shard_path()
end_time = time.time()

write_stats(
self.output_folder,
shard_id,
count,
successes,
failed["failed_to_download"],
failed["failed_to_subsample"],
bytes_downloaded,
shard_status.count,
shard_status.successes,
shard_status.failed["failed_to_download"],
shard_status.failed["failed_to_subsample"],
shard_status.bytes_downloaded,
start_time,
end_time,
status_dict,
shard_status.status_dict,
self.config["storage"]["oom_shard_count"],
)
fs.rm(shard_path)
Loading