Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions python/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ runtime_common = [
"build",
"compressed-tensors",
"datasets",
"video-reader-rs",
"fastapi",
"hf_transfer",
"huggingface_hub",
Expand Down
2 changes: 1 addition & 1 deletion python/sglang/check_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def is_cuda_v2():
"tiktoken",
"anthropic",
"litellm",
"decord",
"video-reader-rs",
]


Expand Down
4 changes: 2 additions & 2 deletions python/sglang/srt/multimodal/processors/base_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,7 @@ def get_estimated_frames_list(self, image_data):
estimate the total frame count from all visual input
"""
# Lazy import: originally needed because decord is not available on some arm platforms; kept lazy for video_reader (TODO: confirm it is still required).
from decord import VideoReader, cpu
from video_reader import PyVideoReader, cpu

# Before processing inputs
if not image_data or len(image_data) == 0:
Expand All @@ -216,7 +216,7 @@ def get_estimated_frames_list(self, image_data):
if isinstance(image, str) and image.startswith("video:"):
path = image[len("video:") :]
# Estimate frames for the video
vr = VideoReader(path, ctx=cpu(0))
vr = PyVideoReader(path, threads=0)
num_frames = len(vr)
else:
# For images, each contributes one frame
Expand Down
4 changes: 2 additions & 2 deletions python/sglang/srt/multimodal/processors/internvl.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,15 +150,15 @@ def get_index(bound, fps, max_frame, first_idx=0, num_segments=32):
def load_video(video_path, bound=None, input_size=448, max_num=1, num_segments=32):
vr = VideoReader(video_path, ctx=cpu(0), num_threads=1)
max_frame = len(vr) - 1
fps = float(vr.get_avg_fps())
fps = float(vr.get_fps())

pixel_values_list, num_patches_list = [], []
transform = InternVLImageProcessor.build_transform(input_size=input_size)
frame_indices = InternVLImageProcessor.get_index(
bound, fps, max_frame, first_idx=0, num_segments=num_segments
)
for frame_index in frame_indices:
img = Image.fromarray(vr[frame_index].asnumpy()).convert("RGB")
img = Image.fromarray(vr[frame_index]).convert("RGB")
img = InternVLImageProcessor.dynamic_preprocess(
img, image_size=input_size, use_thumbnail=True, max_num=max_num
)
Expand Down
4 changes: 2 additions & 2 deletions python/sglang/srt/multimodal/processors/qwen_vl.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,10 +156,10 @@ async def preprocess_video(
# vr: VideoReader, image_factor: int = IMAGE_FACTOR
) -> torch.Tensor:
ele = {}
total_frames, video_fps = len(vr), vr.get_avg_fps()
total_frames, video_fps = len(vr), vr.get_fps()
nframes = smart_nframes({}, total_frames=total_frames, video_fps=video_fps)
idx = torch.linspace(0, total_frames - 1, nframes).round().long().tolist()
video = vr.get_batch(idx).asnumpy()
video = vr.get_batch(idx)
video = torch.tensor(video).permute(0, 3, 1, 2) # Convert to TCHW format
nframes, _, height, width = video.shape
min_pixels = ele.get("min_pixels", VIDEO_MIN_PIXELS)
Expand Down
22 changes: 8 additions & 14 deletions python/sglang/srt/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@
from torch.profiler import ProfilerActivity, profile, record_function
from torch.utils._contextlib import _DecoratorContextManager
from triton.runtime.cache import FileCacheManager
from video_reader import PyVideoReader

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -757,24 +758,17 @@ def load_image(

def load_video(video_file: Union[str, bytes], use_gpu: bool = True):
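# The video backend is imported here (previously decord, now video_reader) to avoid a strange Segmentation fault (core dumped) issue seen with decord; TODO: confirm it still applies.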
from decord import VideoReader, cpu, gpu

try:
from decord.bridge import decord_bridge

ctx = gpu(0)
_ = decord_bridge.get_ctx_device(ctx)
except Exception:
ctx = cpu(0)
from video_reader import PyVideoReader

device = "cuda" if use_gpu and torch.cuda.is_available() else None
tmp_file = None
vr = None
try:
if isinstance(video_file, bytes):
tmp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".mp4")
tmp_file.write(video_file)
tmp_file.close()
vr = VideoReader(tmp_file.name, ctx=ctx)
vr = PyVideoReader(tmp_file.name, device=device, threads=0)
elif isinstance(video_file, str):
if video_file.startswith(("http://", "https://")):
timeout = int(os.getenv("REQUEST_TIMEOUT", "10"))
Expand All @@ -784,22 +778,22 @@ def load_video(video_file: Union[str, bytes], use_gpu: bool = True):
for chunk in response.iter_content(chunk_size=8192):
tmp_file.write(chunk)
tmp_file.close()
vr = VideoReader(tmp_file.name, ctx=ctx)
vr = PyVideoReader(tmp_file.name, device=device, threads=0)
elif video_file.startswith("data:"):
_, encoded = video_file.split(",", 1)
video_bytes = base64.b64decode(encoded)
tmp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".mp4")
tmp_file.write(video_bytes)
tmp_file.close()
vr = VideoReader(tmp_file.name, ctx=ctx)
vr = PyVideoReader(tmp_file.name, device=device, threads=0)
elif os.path.isfile(video_file):
vr = VideoReader(video_file, ctx=ctx)
vr = PyVideoReader(video_file, device=device, threads=0)
else:
video_bytes = base64.b64decode(video_file)
tmp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".mp4")
tmp_file.write(video_bytes)
tmp_file.close()
vr = VideoReader(tmp_file.name, ctx=ctx)
vr = PyVideoReader(tmp_file.name, device=device, threads=0)
else:
raise ValueError(f"Unsupported video input type: {type(video_file)}")

Expand Down
Loading