Merged
Changes from all commits
48 commits
0096460
draft update two models for now
zucchini-nlp Jul 15, 2025
34a2ff1
batch update all VLMs first
zucchini-nlp Jul 18, 2025
4a8b169
update some more image processors
zucchini-nlp Jul 23, 2025
c7ec229
merge main
zucchini-nlp Jul 23, 2025
0c1024a
update
zucchini-nlp Jul 23, 2025
01ad83a
fix a few tests
zucchini-nlp Jul 23, 2025
3187cc4
just make CI green for now
zucchini-nlp Jul 23, 2025
7900ce8
fix copies
zucchini-nlp Jul 24, 2025
47cce51
update once more
zucchini-nlp Jul 24, 2025
ea9a29b
update
zucchini-nlp Jul 24, 2025
4424f55
merge main
zucchini-nlp Jul 24, 2025
9957f9d
unskip the test
zucchini-nlp Jul 24, 2025
1c9ad58
fix these two
zucchini-nlp Jul 28, 2025
595fe00
fix torchcodec audio loading
zucchini-nlp Jul 28, 2025
3f66126
maybe
zucchini-nlp Aug 1, 2025
4c5a674
merge main
zucchini-nlp Aug 1, 2025
ba02dec
yay, i fixed torchcodec installation and now can actually test it
zucchini-nlp Aug 1, 2025
272054f
Merge remote-tracking branch 'upstream/main' into video-decoding
zucchini-nlp Aug 4, 2025
c05f31c
fix copies deepseek
zucchini-nlp Aug 4, 2025
0fe6e26
make sure the metadata is returrned when users request it
zucchini-nlp Aug 4, 2025
6286a8a
add docs
zucchini-nlp Aug 4, 2025
1a78709
update
zucchini-nlp Aug 4, 2025
c9562f4
merge main
zucchini-nlp Aug 4, 2025
86ab24a
fixup
zucchini-nlp Aug 4, 2025
f8f2506
Merge branch 'main' into video-decoding
zucchini-nlp Aug 5, 2025
0997fda
Update src/transformers/audio_utils.py
zucchini-nlp Aug 7, 2025
40f6f3f
Update src/transformers/models/glm4v/video_processing_glm4v.py
zucchini-nlp Aug 7, 2025
ce9750d
update
zucchini-nlp Aug 7, 2025
7375c02
what if we set some metadata attr to `None`
zucchini-nlp Aug 7, 2025
66eed23
fix CI
zucchini-nlp Aug 7, 2025
194d714
fix one test
zucchini-nlp Aug 7, 2025
adc8299
fix 4 channel test
zucchini-nlp Aug 7, 2025
1b34bba
fix glm timestemps
zucchini-nlp Aug 8, 2025
5f5c959
Merge branch 'main' into video-decoding
zucchini-nlp Aug 8, 2025
4b4f297
Merge branch 'main' into video-decoding
zucchini-nlp Aug 12, 2025
d5723de
rebase gone wrong
zucchini-nlp Aug 14, 2025
bc0405f
raise warning once
zucchini-nlp Aug 19, 2025
de9f7fb
Merge branch 'main' into video-decoding
zucchini-nlp Aug 19, 2025
baeb0e4
fixup
zucchini-nlp Aug 19, 2025
1fb826c
typo
zucchini-nlp Aug 19, 2025
c0a1c62
fix copies
zucchini-nlp Aug 20, 2025
248716c
ifx smolvlm test
zucchini-nlp Aug 20, 2025
ca0e8ae
this is why torch's official benchmark was faster, set threads to `0`
zucchini-nlp Aug 21, 2025
ceeab83
Merge branch 'main' into video-decoding
zucchini-nlp Aug 25, 2025
bf7e1d3
Merge branch 'main' into video-decoding
zucchini-nlp Aug 25, 2025
cd4f073
Merge branch 'main' into video-decoding
zucchini-nlp Aug 25, 2025
97e672f
Merge branch 'main' into video-decoding
zucchini-nlp Aug 26, 2025
1fb502b
Apply style fixes
github-actions[bot] Aug 26, 2025
3 changes: 1 addition & 2 deletions docs/source/en/main_classes/image_processor.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,7 @@ rendered properly in your Markdown viewer.

# Image Processor

An image processor is in charge of preparing input features for vision models and post processing their outputs. This includes transformations such as resizing, normalization, and conversion to Numpy and PyTorch tensors. It may also include model specific post-processing such as converting logits to segmentation masks.

An image processor is in charge of (optionally) loading images, preparing input features for vision models, and post-processing their outputs. This includes transformations such as resizing, normalization, and conversion to PyTorch and NumPy tensors. It may also include model-specific post-processing such as converting logits to segmentation masks.
Fast image processors are available for a few models and more will be added in the future. They are based on the [torchvision](https://pytorch.org/vision/stable/index.html) library and provide a significant speed-up, especially when processing on GPU.
They have the same API as the base image processors and can be used as drop-in replacements.
To use a fast image processor, you need to install the `torchvision` library, and set the `use_fast` argument to `True` when instantiating the image processor:
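A minimal sketch of that pattern (the checkpoint name here is illustrative; any model with a fast image processor works the same way):

```python
from transformers import AutoImageProcessor

# `use_fast=True` selects the torchvision-based fast image processor when one is available
processor = AutoImageProcessor.from_pretrained("llava-hf/llava-onevision-qwen2-0.5b-ov-hf", use_fast=True)
```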
Expand Down
44 changes: 42 additions & 2 deletions docs/source/en/main_classes/video_processor.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,9 @@ rendered properly in your Markdown viewer.

-->


# Video Processor

A **Video Processor** is a utility responsible for preparing input features for video models, as well as handling the post-processing of their outputs. It provides transformations such as resizing, normalization, and conversion into PyTorch.
A **Video Processor** is a utility responsible for preparing input features for video models, as well as handling the post-processing of their outputs. It provides transformations such as resizing, normalization, and conversion into PyTorch tensors. Along with these transformations, the `VideoProcessor` class handles video decoding from local paths or URLs (requires [`torchcodec`](https://pypi.org/project/torchcodec/)) and frame sampling according to model-specific strategies.

The video processor extends the functionality of image processors by allowing Vision Large Language Models (VLMs) to handle videos with a distinct set of arguments compared to images. It serves as the bridge between raw video data and the model, ensuring that input features are optimized for the VLM.

Expand Down Expand Up @@ -48,6 +47,47 @@ processor = torch.compile(processor)
processed_video = processor(video, return_tensors="pt")
```

#### Sampling behavior

The video processor can also sample video frames using the technique best suited for the given model. Sampling behavior is controlled with the `do_sample_frames` argument and can be configured through model-specific parameters such as `num_frames` or `fps` (the rate at which the video will be sampled). If the input video is given as a local path or URL (`str`), the processor will decode it automatically. To obtain metadata about the decoded video, such as sampled frame indices, original dimensions, duration, and fps, pass `return_metadata=True` to the processor.

<Tip warning={false}>

- Specifying `num_frames` does not guarantee the output will contain exactly that number of frames. Depending on the model, the sampler may enforce minimum or maximum frame limits.

- The default decoder is [`torchcodec`](https://pypi.org/project/torchcodec/), which must be installed.

</Tip>


```python
from transformers import AutoVideoProcessor

processor = AutoVideoProcessor.from_pretrained("llava-hf/llava-onevision-qwen2-0.5b-ov-hf", device="cuda")
processed_video_inputs = processor(videos=["video_path.mp4"], return_metadata=True, do_sample_frames=True, return_tensors="pt")
video_metadata = processed_video_inputs["video_metadata"]

# See how many frames the original video had and what its original FPS was
print(video_metadata.total_num_frames, video_metadata.fps)
```

If you pass an already decoded video array but still want to enable model-specific frame sampling, it is strongly recommended to provide `video_metadata`. This allows the sampler to know the original video's duration and FPS. You can pass the metadata as a `VideoMetadata` object or as a plain dict.

```python
import torch

from transformers import AutoVideoProcessor
from transformers.video_utils import VideoMetadata

processor = AutoVideoProcessor.from_pretrained("llava-hf/llava-onevision-qwen2-0.5b-ov-hf", device="cuda")
my_decoded_video = torch.randint(0, 255, size=(100, 3, 1280, 1280))  # short video of 100 frames
video_metadata = VideoMetadata(
    total_num_frames=100,
    fps=24,
    duration=4.1,  # in seconds
)
processed_video_inputs = processor(videos=[my_decoded_video], video_metadata=video_metadata, do_sample_frames=True, num_frames=10, return_tensors="pt")
print(processed_video_inputs.pixel_values_videos.shape)
>>> [10, 3, 384, 384]
```
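The same call also accepts the metadata as a plain dict; a minimal sketch, assuming the keys mirror the `VideoMetadata` fields:

```python
video_metadata = {"total_num_frames": 100, "fps": 24, "duration": 4.1}
processed_video_inputs = processor(
    videos=[my_decoded_video], video_metadata=video_metadata, do_sample_frames=True, num_frames=10, return_tensors="pt"
)
```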

## BaseVideoProcessor

Expand Down
76 changes: 63 additions & 13 deletions src/transformers/audio_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
import os
import warnings
from io import BytesIO
from typing import Any, Optional, Union
from typing import Any, Optional, Sequence, Union

import numpy as np
import requests
Expand All @@ -31,6 +31,7 @@
is_numpy_array,
is_soundfile_available,
is_torch_tensor,
is_torchcodec_available,
requires_backends,
)

Expand All @@ -44,6 +45,12 @@
# TODO: @eustlb, we actually don't need librosa but soxr is installed with librosa
import soxr

if is_torchcodec_available():
from torchcodec.decoders import AudioDecoder


AudioInput = Union[np.ndarray, "torch.Tensor", Sequence[np.ndarray], Sequence["torch.Tensor"]] # noqa: F821


def load_audio(audio: Union[str, np.ndarray], sampling_rate=16000, timeout=None) -> np.ndarray:
"""
Expand All @@ -61,14 +68,14 @@ def load_audio(audio: Union[str, np.ndarray], sampling_rate=16000, timeout=None)
Returns:
`np.ndarray`: A numpy array representing the audio.
"""
requires_backends(load_audio, ["librosa"])

if isinstance(audio, str):
# Load audio from URL (e.g https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen2-Audio/audio/translate_to_chinese.wav)
if audio.startswith("http://") or audio.startswith("https://"):
audio = librosa.load(BytesIO(requests.get(audio, timeout=timeout).content), sr=sampling_rate)[0]
elif os.path.isfile(audio):
audio = librosa.load(audio, sr=sampling_rate)[0]
# Try to load with `torchcodec`, but do not require users to install it. If it is not found,
# fall back to `librosa`. For audio-only models, `torchcodec` is most likely not
# needed.
if is_torchcodec_available():
audio = load_audio_torchcodec(audio, sampling_rate=sampling_rate)
else:
audio = load_audio_librosa(audio, sampling_rate=sampling_rate, timeout=timeout)
elif isinstance(audio, np.ndarray):
audio = audio
else:
Expand All @@ -78,6 +85,54 @@ def load_audio(audio: Union[str, np.ndarray], sampling_rate=16000, timeout=None)
return audio


def load_audio_torchcodec(audio: Union[str, np.ndarray], sampling_rate=16000) -> np.ndarray:
"""
Loads `audio` to an np.ndarray object using `torchcodec`.

Args:
audio (`str` or `np.ndarray`):
The audio to be loaded to the numpy array format.
sampling_rate (`int`, *optional*, defaults to 16000):
The sampling rate to be used when loading the audio. It should be the same as the
sampling rate the model you will be using was trained with.

Returns:
`np.ndarray`: A numpy array representing the audio.
"""
requires_backends(load_audio, ["torchcodec"])

# Set `num_channels` to `1`, which is what most models expect and the default in librosa
decoder = AudioDecoder(audio, sample_rate=sampling_rate, num_channels=1)
audio = decoder.get_all_samples().data[0].numpy() # NOTE: feature extractors don't accept torch tensors
return audio


def load_audio_librosa(audio: Union[str, np.ndarray], sampling_rate=16000, timeout=None) -> np.ndarray:
"""
Loads `audio` to an np.ndarray object using `librosa`.

Args:
audio (`str` or `np.ndarray`):
The audio to be loaded to the numpy array format.
sampling_rate (`int`, *optional*, defaults to 16000):
The sampling rate to be used when loading the audio. It should be the same as the
sampling rate the model you will be using was trained with.
timeout (`float`, *optional*):
The timeout value in seconds for the URL request.

Returns:
`np.ndarray`: A numpy array representing the audio.
"""
requires_backends(load_audio, ["librosa"])

# Load audio from URL (e.g https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen2-Audio/audio/translate_to_chinese.wav)
if audio.startswith("http://") or audio.startswith("https://"):
audio = librosa.load(BytesIO(requests.get(audio, timeout=timeout).content), sr=sampling_rate)[0]
elif os.path.isfile(audio):
audio = librosa.load(audio, sr=sampling_rate)[0]
return audio


def load_audio_as(
audio: str,
return_format: str,
Expand Down Expand Up @@ -157,11 +212,6 @@ def load_audio_as(
raise ValueError(f"Error loading audio: {e}")


AudioInput = Union[
np.ndarray, "torch.Tensor", list[np.ndarray], tuple[np.ndarray], list["torch.Tensor"], tuple["torch.Tensor"] # noqa: F821
]


def is_valid_audio(audio):
return is_numpy_array(audio) or is_torch_tensor(audio)

Expand Down
22 changes: 5 additions & 17 deletions src/transformers/image_processing_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,13 @@
import json
import os
import warnings
from io import BytesIO
from typing import Any, Optional, TypeVar, Union

import numpy as np
import requests

from .dynamic_module_utils import custom_object_save
from .feature_extraction_utils import BatchFeature as BaseBatchFeature
from .image_utils import is_valid_image, load_image
from .utils import (
IMAGE_PROCESSOR_NAME,
PushToHubMixin,
Expand All @@ -33,15 +32,10 @@
download_url,
is_offline_mode,
is_remote_url,
is_vision_available,
logging,
)


if is_vision_available():
from PIL import Image


ImageProcessorType = TypeVar("ImageProcessorType", bound="ImageProcessingMixin")


Expand Down Expand Up @@ -514,25 +508,19 @@ def register_for_auto_class(cls, auto_class="AutoImageProcessor"):

cls._auto_class = auto_class

def fetch_images(self, image_url_or_urls: Union[str, list[str]]):
def fetch_images(self, image_url_or_urls: Union[str, list[str], list[list[str]]]):
"""
Convert a single URL or a (possibly nested) list of URLs into the corresponding `PIL.Image` objects.

If a single URL is passed, the return value will be a single object. If a list is passed, a list of objects is
returned.
"""
headers = {
"User-Agent": (
"Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/114.0.0.0"
" Safari/537.36"
)
}
if isinstance(image_url_or_urls, list):
return [self.fetch_images(x) for x in image_url_or_urls]
elif isinstance(image_url_or_urls, str):
response = requests.get(image_url_or_urls, stream=True, headers=headers)
response.raise_for_status()
return Image.open(BytesIO(response.content))
return load_image(image_url_or_urls)
elif is_valid_image(image_url_or_urls):
return image_url_or_urls
else:
raise TypeError(f"only a single or a list of entries is supported but got type={type(image_url_or_urls)}")

Expand Down
31 changes: 15 additions & 16 deletions src/transformers/image_processing_utils_fast.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ def validate_fast_preprocess_arguments(
crop_size: Optional[SizeDict] = None,
do_resize: Optional[bool] = None,
size: Optional[SizeDict] = None,
resample: Optional["PILImageResampling"] = None,
interpolation: Optional["F.InterpolationMode"] = None,
return_tensors: Optional[Union[str, TensorType]] = None,
data_format: Optional[ChannelDimension] = ChannelDimension.FIRST,
):
Expand All @@ -105,7 +105,7 @@ def validate_fast_preprocess_arguments(
crop_size=crop_size,
do_resize=do_resize,
size=size,
resample=resample,
interpolation=interpolation,
)
# Extra checks for ImageProcessorFast
if return_tensors is not None and return_tensors != "pt":
Expand Down Expand Up @@ -469,6 +469,8 @@ def _prepare_images_structure(
Returns:
`ImageInput`: The images with a valid nesting.
"""
# Check for `str` entries (URLs or local paths) and load them as images when needed
images = self.fetch_images(images)
return make_flat_list_of_images(images, expected_ndims=expected_ndims)

def _process_image(
Expand Down Expand Up @@ -582,11 +584,19 @@ def _further_process_kwargs(

kwargs["size"] = size
kwargs["crop_size"] = crop_size
kwargs["default_to_square"] = default_to_square
kwargs["image_mean"] = image_mean
kwargs["image_std"] = image_std
kwargs["data_format"] = data_format

# torch resize uses interpolation instead of resample
# Check if resample is an int before checking if it's an instance of PILImageResampling
# because if pillow < 9.1.0, resample is an int and PILImageResampling is a module.
# Checking PILImageResampling will fail with error `TypeError: isinstance() arg 2 must be a type or tuple of types`.
resample = kwargs.pop("resample")
kwargs["interpolation"] = (
pil_torch_interpolation_mapping[resample] if isinstance(resample, (PILImageResampling, int)) else resample
)

return kwargs

def _validate_preprocess_kwargs(
Expand All @@ -600,7 +610,7 @@ def _validate_preprocess_kwargs(
size: Optional[SizeDict] = None,
do_center_crop: Optional[bool] = None,
crop_size: Optional[SizeDict] = None,
resample: Optional[Union["PILImageResampling", "F.InterpolationMode"]] = None,
interpolation: Optional["F.InterpolationMode"] = None,
return_tensors: Optional[Union[str, TensorType]] = None,
data_format: Optional[ChannelDimension] = None,
**kwargs,
Expand All @@ -618,7 +628,7 @@ def _validate_preprocess_kwargs(
size=size,
do_center_crop=do_center_crop,
crop_size=crop_size,
resample=resample,
interpolation=interpolation,
return_tensors=return_tensors,
data_format=data_format,
)
Expand Down Expand Up @@ -646,18 +656,7 @@ def preprocess(self, images: ImageInput, *args, **kwargs: Unpack[DefaultFastImag
# Validate kwargs
self._validate_preprocess_kwargs(**kwargs)

# torch resize uses interpolation instead of resample
resample = kwargs.pop("resample")

# Check if resample is an int before checking if it's an instance of PILImageResampling
# because if pillow < 9.1.0, resample is an int and PILImageResampling is a module.
# Checking PILImageResampling will fail with error `TypeError: isinstance() arg 2 must be a type or tuple of types`.
kwargs["interpolation"] = (
pil_torch_interpolation_mapping[resample] if isinstance(resample, (int, PILImageResampling)) else resample
)

# Pop kwargs that are not needed in _preprocess
kwargs.pop("default_to_square")
kwargs.pop("data_format")

return self._preprocess_image_like_inputs(
Expand Down
10 changes: 8 additions & 2 deletions src/transformers/image_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -535,6 +535,7 @@ def validate_preprocess_arguments(
do_resize: Optional[bool] = None,
size: Optional[dict[str, int]] = None,
resample: Optional["PILImageResampling"] = None,
interpolation: Optional["InterpolationMode"] = None,
):
"""
Checks validity of typically used arguments in an `ImageProcessor` `preprocess` method.
Expand All @@ -559,8 +560,13 @@ def validate_preprocess_arguments(
if do_center_crop and crop_size is None:
raise ValueError("`crop_size` must be specified if `do_center_crop` is `True`.")

if do_resize and (size is None or resample is None):
raise ValueError("`size` and `resample` must be specified if `do_resize` is `True`.")
if interpolation is not None and resample is not None:
raise ValueError(
"Only one of `interpolation` and `resample` should be specified, depending on image processor type."
)

if do_resize and not (size is not None and (resample is not None or interpolation is not None)):
raise ValueError("`size` and `resample/interpolation` must be specified if `do_resize` is `True`.")


# In the future we can add a TF implementation here when we have TF models.
Expand Down
1 change: 1 addition & 0 deletions src/transformers/models/aria/image_processing_aria.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,7 @@ def preprocess(
if max_image_size not in [490, 980]:
raise ValueError("max_image_size must be either 490 or 980")

images = self.fetch_images(images)
images = make_flat_list_of_images(images)

if not valid_images(images):
Expand Down
1 change: 1 addition & 0 deletions src/transformers/models/aria/modular_aria.py
Original file line number Diff line number Diff line change
Expand Up @@ -614,6 +614,7 @@ def preprocess(
if max_image_size not in [490, 980]:
raise ValueError("max_image_size must be either 490 or 980")

images = self.fetch_images(images)
images = make_flat_list_of_images(images)

if not valid_images(images):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,7 @@ def __call__(
# Process images
image_inputs = {}
if images is not None:
images = self.image_processor.fetch_images(images)
images = make_flat_list_of_images(images)
image_inputs = self.image_processor(images=images, **output_kwargs["images_kwargs"])
num_patches = image_inputs.pop("num_patches")
Expand Down
1 change: 1 addition & 0 deletions src/transformers/models/blip/image_processing_blip.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,6 +231,7 @@ def preprocess(

size = size if size is not None else self.size
size = get_size_dict(size, default_to_square=False)
images = self.fetch_images(images)
images = make_flat_list_of_images(images)

if not valid_images(images):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -465,6 +465,7 @@ def preprocess(

size = size if size is not None else self.size
size = get_size_dict(size, default_to_square=False)
images = self.fetch_images(images)
images = make_flat_list_of_images(images)

if not valid_images(images):
Expand Down