48 changes: 27 additions & 21 deletions fiftyone/core/models.py
@@ -1009,32 +1009,37 @@ def compute_embeddings(
else:
field_mapping = None

process_video_frames = (
samples.media_type == fom.VIDEO and model.media_type == "image"
)

use_data_loader = (
isinstance(model, (SupportsGetItem, TorchModelMixin))
and not process_video_frames
)
with contextlib.ExitStack() as context:
if hasattr(model, "mode") and model.mode is None:
context.enter_context(
fou.SetAttributes(model, mode=samples.media_type)
)

if num_workers is not None and not use_data_loader:
logger.warning("Ignoring unsupported `num_workers` parameter")
process_video_frames = (
samples.media_type == fom.VIDEO and model.media_type == "image"
)

if embeddings_field is not None:
dataset = samples._dataset
embeddings_field, _is_frame_field = dataset._handle_frame_field(
embeddings_field
use_data_loader = (
isinstance(model, (SupportsGetItem, TorchModelMixin))
and not process_video_frames
)

if dataset.media_type == fom.VIDEO and model.media_type == "image":
if not dataset.has_frame_field(embeddings_field):
dataset.add_frame_field(embeddings_field, fof.VectorField)
else:
if not dataset.has_sample_field(embeddings_field):
dataset.add_sample_field(embeddings_field, fof.VectorField)
if num_workers is not None and not use_data_loader:
logger.warning("Ignoring unsupported `num_workers` parameter")

if embeddings_field is not None:
dataset = samples._dataset
embeddings_field, _is_frame_field = dataset._handle_frame_field(
embeddings_field
)

if dataset.media_type == fom.VIDEO and model.media_type == "image":
if not dataset.has_frame_field(embeddings_field):
dataset.add_frame_field(embeddings_field, fof.VectorField)
else:
if not dataset.has_sample_field(embeddings_field):
dataset.add_sample_field(embeddings_field, fof.VectorField)

with contextlib.ExitStack() as context:
if use_data_loader:
context.enter_context(fou.SetAttributes(model, preprocess=False))

@@ -1449,6 +1454,7 @@ def _compute_video_embeddings(
raise e

errors = True
embedding = None
logger.warning("Sample: %s\nError: %s\n", sample.id, e)

if embeddings_field is not None:
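For context, a minimal sketch of how the reworked compute_embeddings() plays with a mode-aware model on a video dataset; the zoo model name, the extra kwargs, and the field name below are illustrative assumptions, not part of this diff:

import fiftyone.zoo as foz

dataset = foz.load_zoo_dataset("quickstart-video")

# Hypothetical zoo name; extra kwargs are forwarded to the model's config
model = foz.load_zoo_model(
    "qwen3-vl-2b-instruct", video_fps=1.0, max_video_frames=64
)

# model.mode is None by default, so compute_embeddings() temporarily sets
# it to the dataset's media type ("video") via fou.SetAttributes, and each
# sample is embedded through the model's native video path
dataset.compute_embeddings(model, embeddings_field="qwen_embeddings")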
155 changes: 138 additions & 17 deletions fiftyone/utils/qwen3_vl.py
@@ -10,6 +10,8 @@
import json
import logging

import eta.core.video as etav

import numpy as np

import fiftyone.core.labels as fol
@@ -25,12 +27,17 @@


def _ensure_qwen3_vl():
fou.ensure_package("transformers>=4.51.0")
fou.ensure_package("transformers>=4.57.0")
fou.ensure_package("accelerate")
fou.ensure_package("qwen-vl-utils")
# 0.0.14 double-applies smart_resize on frame-list videos
# (QwenLM/Qwen3-VL#2045); harmless for embeddings.
fou.ensure_package("qwen-vl-utils>=0.0.1")


transformers = fou.lazy_import("transformers", callback=_ensure_qwen3_vl)
qwen_vl_utils = fou.lazy_import("qwen_vl_utils", callback=_ensure_qwen3_vl)

from PIL import Image as PILImage


DEFAULT_QWEN3_VL_MODEL = "Qwen/Qwen3-VL-2B-Instruct"
@@ -146,6 +153,12 @@ class Qwen3VLModelConfig(fout.TorchImageModelConfig, fozm.HasZooModel):
embedding_dim (None): output embedding dimension for MRL truncation;
if None, uses full model dimension (2048 for 2B, 3584 for 8B)
normalize_embeddings (True): whether to L2 normalize embeddings
video_fps (2.0): frame sampling rate for video inputs; Qwen3-VL's
default is 2.0 FPS. Lower values sample fewer frames and run faster
max_video_frames (128): maximum frames to sample from a video;
prevents OOM on long videos. Matches qwen-vl-utils MAX_FRAMES.
mode (None): the media type mode, "image" or "video"; if None,
defaults to the dataset's media type at inference time
"""

def __init__(self, d):
@@ -162,6 +175,21 @@ def __init__(self, d):
self.normalize_embeddings = self.parse_bool(
d, "normalize_embeddings", default=True
)
self.video_fps = self.parse_number(d, "video_fps", default=2.0)
if self.video_fps <= 0:
raise ValueError(
f"video_fps must be positive, got {self.video_fps}"
)
self.max_video_frames = self.parse_int(d, "max_video_frames", default=128)
if self.max_video_frames <= 0:
raise ValueError(
f"max_video_frames must be positive, got {self.max_video_frames}"
)
self.mode = self.parse_string(d, "mode", default=None)
if self.mode is not None and self.mode not in ("image", "video"):
raise ValueError(
"mode must be 'image', 'video', or None; got %r" % self.mode
)

self.raw_inputs = True
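
A quick note on accepted values (illustrative; constructed directly on the config class, assuming the inherited TorchImageModelConfig keys are all optional, which this diff does not show):

from fiftyone.utils.qwen3_vl import Qwen3VLModelConfig

# Defaults added in this hunk
config = Qwen3VLModelConfig({})
print(config.video_fps, config.max_video_frames, config.mode)  # 2.0 128 None

# Invalid values raise immediately at config-parse time
for bad in ({"video_fps": 0}, {"max_video_frames": -1}, {"mode": "audio"}):
    try:
        Qwen3VLModelConfig(bad)
    except ValueError as e:
        print(e)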

@@ -217,8 +245,21 @@ class Qwen3VLModel(fout.TorchImageModel, fom.EmbeddingsMixin):

def __init__(self, config):
self._processor = None
self._mode = config.mode
super().__init__(config)

@property
def mode(self):
return self._mode

@mode.setter
def mode(self, value):
if value not in (None, "image", "video"):
raise ValueError(
"mode must be 'image', 'video', or None; got %r" % value
)
self._mode = value

@property
def has_embeddings(self):
return self._output_processor is None
@@ -245,7 +286,7 @@ def _load_model(self, config):

@property
def media_type(self):
return "image"
return self._mode or "image"

def _get_prompt(self):
if self.config.prompt is not None:
@@ -269,8 +310,6 @@ def _forward_pass(self, imgs):

def _generate_detections(self, imgs):
"""Generate detection output via text generation."""
from PIL import Image as PILImage

prompt = self._get_prompt()
results = []

@@ -349,23 +388,24 @@ def _embed_images(self, imgs):
return_dict=True,
)

last_hidden = outputs.hidden_states[-1]
embedding = last_hidden[:, -1, :]
embeddings.append(self._postprocess_embedding(outputs))

if self.config.embedding_dim is not None:
embedding = embedding[:, : self.config.embedding_dim]
return np.vstack(embeddings)

if self.config.normalize_embeddings:
embedding = torch.nn.functional.normalize(embedding, p=2, dim=-1)
def _postprocess_embedding(self, outputs):
last_hidden = outputs.hidden_states[-1]
embedding = last_hidden[:, -1, :]

embeddings.append(embedding.float().cpu().numpy())
if self.config.embedding_dim is not None:
embedding = embedding[:, : self.config.embedding_dim]

return np.vstack(embeddings)
if self.config.normalize_embeddings:
embedding = torch.nn.functional.normalize(embedding, p=2, dim=-1)

return embedding.float().cpu().numpy()

def _prepare_image(self, img):
"""Convert image to PIL format for processor."""
Contributor review comment [Re: line +406]: remove import

from PIL import Image as PILImage

if isinstance(img, torch.Tensor):
img = img.cpu().numpy()
# Transpose CHW to HWC if first dim is channels and last dim is not
Expand All @@ -383,14 +423,19 @@ def _prepare_image(self, img):
return img

def embed(self, arg):
"""Generate embedding for a single image.
"""Generate embedding for a single image or video.

Args:
arg: a PIL image, numpy array, or torch tensor
arg: a PIL image, numpy array, torch tensor, or an active
(entered) ``eta.core.video.FFmpegVideoReader`` context
manager

Returns:
a 1D numpy array embedding
"""
if isinstance(arg, etav.FFmpegVideoReader):
return self._embed_video(arg)

return self.embed_all([arg])[0]
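
As a usage note, the video path of embed() expects an already-entered FFmpegVideoReader, matching how FiftyOne's video embedding loop feeds it; the path and the loaded `model` below are placeholders:

import eta.core.video as etav

# `model` is a loaded Qwen3VLModel instance
with etav.FFmpegVideoReader("/path/to/video.mp4") as reader:
    video_embedding = model.embed(reader)  # one 1D vector for the whole clip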

def embed_all(self, args):
@@ -403,3 +448,79 @@ def embed_all(self, args):
a ``num_images x embedding_dim`` numpy array
"""
return self._predict_all(args)

def _embed_video(self, video_reader):
"""Generate a single embedding for a video via native Qwen3-VL video input.

Samples frames at ``config.video_fps`` and passes them as a video
message to Qwen3-VL, which processes the full temporal context
and returns one embedding vector.

Args:
video_reader: an ``eta.core.video.FFmpegVideoReader``

Returns:
a 1D numpy array embedding
"""
raw_fps = video_reader.frame_rate
sample_fps = self.config.video_fps

if raw_fps > 0 and sample_fps > 0:
step = max(1, round(raw_fps / sample_fps))
else:
step = 1

frames = []
for i, frame in enumerate(video_reader):
if i % step == 0:
frames.append(self._prepare_image(frame))
if len(frames) >= self.config.max_video_frames:
break

if not frames:
raise ValueError(
"No frames could be sampled from the video; "
"the file may be empty or unreadable"
)

effective_fps = raw_fps / step if raw_fps > 0 else sample_fps

messages = [
{
"role": "user",
"content": [
{
"type": "video",
"video": frames,
"fps": effective_fps,
},
],
}
]

image_inputs, video_inputs = qwen_vl_utils.process_vision_info(messages)
text = self._processor.apply_chat_template(
messages,
tokenize=False,
add_generation_prompt=False,
)
inputs = self._processor(
text=[text],
images=image_inputs,
videos=video_inputs,
return_tensors="pt",
padding=True,
)
inputs = {
k: v.to(self._model.device) if hasattr(v, "to") else v
for k, v in inputs.items()
}

with torch.no_grad():
outputs = self._model(
**inputs,
output_hidden_states=True,
return_dict=True,
)

return self._postprocess_embedding(outputs).squeeze(0)
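
To make the sampling arithmetic above concrete (numbers chosen purely for illustration):

# Worked example of the frame-sampling math in _embed_video()
raw_fps = 30.0                              # source video frame rate
sample_fps = 2.0                            # config.video_fps default
step = max(1, round(raw_fps / sample_fps))  # -> 15: keep every 15th frame
effective_fps = raw_fps / step              # -> 2.0, reported to the processor

# A 10-minute clip (18000 frames) would yield 1200 sampled frames, but the
# default max_video_frames=128 cap stops sampling after 128 frames,
# i.e. roughly the first 64 seconds of the video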