Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions python/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ runtime_common = [
"build",
"compressed-tensors",
"datasets",
"video-reader-rs",
"fastapi",
"hf_transfer",
"huggingface_hub",
Expand Down
2 changes: 1 addition & 1 deletion python/sglang/check_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def is_cuda_v2():
"tiktoken",
"anthropic",
"litellm",
"decord",
"video-reader-rs",
]


Expand Down
5 changes: 2 additions & 3 deletions python/sglang/srt/multimodal/processors/base_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
import torch
from PIL import Image
from transformers import BaseImageProcessorFast

from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem
from sglang.srt.utils import load_audio, load_image, load_video, logger

Expand Down Expand Up @@ -206,7 +205,7 @@ def get_estimated_frames_list(self, image_data):
estimate the total frame count from all visual input
"""
# Lazy import: historically needed because decord is not available on some ARM platforms; the reader now comes from video_reader (video-reader-rs).
from decord import VideoReader, cpu
from video_reader import PyVideoReader, cpu

# Before processing inputs
if not image_data or len(image_data) == 0:
Expand All @@ -216,7 +215,7 @@ def get_estimated_frames_list(self, image_data):
if isinstance(image, str) and image.startswith("video:"):
path = image[len("video:") :]
# Estimate frames for the video
vr = VideoReader(path, ctx=cpu(0))
vr = PyVideoReader(path, threads=0)
num_frames = len(vr)
else:
# For images, each contributes one frame
Expand Down
23 changes: 8 additions & 15 deletions python/sglang/srt/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@
import torch.distributed as dist
import triton
import zmq
from video_reader import PyVideoReader
from fastapi.responses import ORJSONResponse
from packaging import version as pkg_version
from PIL import Image
Expand Down Expand Up @@ -757,24 +758,16 @@ def load_image(

def load_video(video_file: Union[str, bytes], use_gpu: bool = True):
# The video reader is imported inside the function; the original decord import was placed here to avoid a strange segmentation fault (core dumped) issue.
from decord import VideoReader, cpu, gpu

try:
from decord.bridge import decord_bridge

ctx = gpu(0)
_ = decord_bridge.get_ctx_device(ctx)
except Exception:
ctx = cpu(0)

from video_reader import PyVideoReader
device = 'cuda' if use_gpu and torch.cuda.is_available() else None
tmp_file = None
vr = None
try:
if isinstance(video_file, bytes):
tmp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".mp4")
tmp_file.write(video_file)
tmp_file.close()
vr = VideoReader(tmp_file.name, ctx=ctx)
vr = PyVideoReader(tmp_file.name, device=device, threads=0)
elif isinstance(video_file, str):
if video_file.startswith(("http://", "https://")):
timeout = int(os.getenv("REQUEST_TIMEOUT", "10"))
Expand All @@ -784,22 +777,22 @@ def load_video(video_file: Union[str, bytes], use_gpu: bool = True):
for chunk in response.iter_content(chunk_size=8192):
tmp_file.write(chunk)
tmp_file.close()
vr = VideoReader(tmp_file.name, ctx=ctx)
vr = PyVideoReader(tmp_file.name, device=device, threads=0)
elif video_file.startswith("data:"):
_, encoded = video_file.split(",", 1)
video_bytes = base64.b64decode(encoded)
tmp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".mp4")
tmp_file.write(video_bytes)
tmp_file.close()
vr = VideoReader(tmp_file.name, ctx=ctx)
vr = PyVideoReader(tmp_file.name, device=device, threads=0)
elif os.path.isfile(video_file):
vr = VideoReader(video_file, ctx=ctx)
vr = PyVideoReader(video_file, device=device, threads=0)
else:
video_bytes = base64.b64decode(video_file)
tmp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".mp4")
tmp_file.write(video_bytes)
tmp_file.close()
vr = VideoReader(tmp_file.name, ctx=ctx)
vr = PyVideoReader(tmp_file.name, device=device, threads=0)
else:
raise ValueError(f"Unsupported video input type: {type(video_file)}")

Expand Down
Loading