Skip to content
Merged
Show file tree
Hide file tree
Changes from 36 commits
Commits
Show all changes
39 commits
Select commit Hold shift + click to select a range
e68cd14
refactor: unify feature field of MultimodalDataItem
mickqian Jul 16, 2025
47e3b99
update
mickqian Jul 16, 2025
159868a
update
mickqian Jul 16, 2025
b3c745e
update
mickqian Jul 16, 2025
874375d
update
mickqian Jul 16, 2025
ddc2623
update
mickqian Jul 16, 2025
1400d00
tensor_transport_mode
mickqian May 3, 2025
4f2e3f9
tp=0 and tp>0 working
mickqian May 3, 2025
5f81bb0
working
mickqian Jun 8, 2025
4f1b017
fallback on tp
mickqian Jun 9, 2025
22ff65e
cleanup
mickqian Jun 9, 2025
46faf28
cleanup
mickqian Jun 9, 2025
fc3ef78
refactor with TransportableTensor
mickqian Jun 10, 2025
6cd047a
work with tp
mickqian Jun 10, 2025
ea51000
update
mickqian Jun 10, 2025
4658110
remove deepcopy
mickqian Jun 11, 2025
fcbe5c3
cleanup
mickqian Jun 11, 2025
ddd2585
cleanup
mickqian Jun 11, 2025
7f2b92b
replace awkward __getstate__
mickqian Jul 16, 2025
9bf97b8
update
mickqian Jul 16, 2025
bd222f5
cleanup
mickqian Jul 17, 2025
31bb7a2
cleanup
mickqian Jul 17, 2025
ab8cdd3
revert mmmu related
mickqian Jul 17, 2025
736acfb
revert mmmu related
mickqian Jul 17, 2025
2c2e323
cleanup
mickqian Jul 17, 2025
81ba186
processor
mickqian Jul 17, 2025
7a12fbb
TransportableTensor __getstate__
mickqian Jul 17, 2025
aa5c34d
rename
mickqian Jul 18, 2025
2559180
update
mickqian Jul 18, 2025
886bd05
fix
mickqian Jul 18, 2025
56c07c1
upd
mickqian Jul 19, 2025
631d4b5
upd
mickqian Jul 21, 2025
f2379c4
Merge branch 'main' into optimize_req
JustinTong0323 Jul 24, 2025
a86de63
remove duplicate multimodal_processors/qwen_audio.py
JustinTong0323 Jul 24, 2025
fded26d
refactor: update terminology from 'precomputed features' to 'precompu…
JustinTong0323 Jul 24, 2025
deec266
refactor: update constructors in multimodal processors to accept addi…
JustinTong0323 Jul 24, 2025
ef75f54
fix
mickqian Jul 24, 2025
8cd0b75
fix: rename precomputed_features to precomputed_embeddings
JustinTong0323 Jul 25, 2025
54a8f31
Merge branch 'main' into optimize_req
JustinTong0323 Jul 25, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
127 changes: 126 additions & 1 deletion python/sglang/srt/managers/mm_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,9 @@
"""

import hashlib
import pickle
from abc import abstractmethod
from typing import Callable, Dict, List, Optional, Tuple
from typing import Any, Callable, Dict, List, Literal, Optional, Tuple

import numpy as np
import torch
Expand All @@ -27,6 +28,130 @@
# propagation that can cause some log messages (like 'server is fired up') to not appear
# in the console when multimodal support is enabled.

# TODO(mick): nccl
# cuda_ipc: for intranode tensor sharing
TensorTransportMode = Literal["cuda_ipc", "auto", "default"]


class TransportProxyTensor(torch.Tensor):
    """
    A torch.Tensor subclass that carries extra metadata (name, fields) and
    supports efficient inter-process communication.

    When ``transport_mode == "cuda_ipc"`` and the tensor lives on a CUDA
    device, pickling serializes a CUDA IPC handle instead of the tensor
    bytes, so a consumer process on the same node can map the memory
    without a device->host->device copy. Any other configuration falls
    back to regular tensor serialization ("default" transport).
    """

    @staticmethod
    def __new__(
        cls,
        data: torch.Tensor,
        name: Optional[str] = None,
        fields: Optional[Dict[str, Any]] = None,
        transport_mode: TensorTransportMode = "default",
        *args,
        **kwargs,
    ):
        """Wrap an existing tensor, attaching transport metadata.

        Args:
            data: the tensor to wrap (shares storage; no copy is made).
            name: optional human-readable identifier carried with the tensor.
            fields: optional free-form metadata dictionary.
            transport_mode: how the tensor should be shipped when pickled.

        Raises:
            TypeError: if ``data`` is not a torch.Tensor.
        """
        if not isinstance(data, torch.Tensor):
            raise TypeError(
                f"Input 'data' must be a torch.Tensor, but got {type(data)}"
            )

        instance = data.as_subclass(cls)

        instance._metadata = {
            "name": name,
            "fields": fields if fields is not None else {},
            "transport_mode": transport_mode,
        }

        return instance

    def __getstate__(self):
        """
        Called during pickling. Implements the serialization logic.
        """
        state = {
            # Copy the metadata dict: the cuda_ipc -> default fallback below
            # rewrites "transport_mode" in the state, and that rewrite must
            # not leak into this (still alive) tensor's own metadata.
            "metadata": dict(self._metadata),
            "tensor_data": None,
            "ipc_extra": None,
        }

        transport_mode = self._metadata.get("transport_mode", "default")

        if transport_mode == "cuda_ipc" and self.is_cuda:
            try:
                storage = self.untyped_storage()
                handle = storage._share_cuda_()

                # Ship only the IPC handle plus the geometry needed to
                # reconstruct a view of the shared storage on the other side.
                state["ipc_extra"] = {
                    "handle": handle,
                    "shape": self.shape,
                    "dtype": self.dtype,
                    "stride": self.stride(),
                    "device_index": self.device.index,
                }
                return state
            except Exception as e:
                print_warning_once(
                    f"Warning: Failed to get CUDA IPC handle ({e}). Falling back to default transport."
                )

        # Default transport: serialize the raw tensor data as a plain
        # torch.Tensor (moves through CPU memory during pickling).
        state["metadata"]["transport_mode"] = "default"
        state["tensor_data"] = self.as_subclass(torch.Tensor)

        return state

    def __setstate__(self, state: Dict[str, Any]):
        """
        Called during unpickling. Implements the deserialization logic.
        """
        self._metadata = state["metadata"]

        transport_mode = self._metadata.get("transport_mode", "default")

        if transport_mode == "cuda_ipc" and state["ipc_extra"] is not None:
            ipc_extra = state["ipc_extra"]
            handle, shape, dtype, stride, source_device_index = (
                ipc_extra["handle"],
                ipc_extra["shape"],
                ipc_extra["dtype"],
                ipc_extra["stride"],
                ipc_extra["device_index"],
            )

            try:
                # The IPC handle is only valid on the producer's device, so
                # map it there before building the tensor view.
                target_device = torch.device(f"cuda:{source_device_index}")
                with torch.cuda.device(target_device):
                    storage = torch.UntypedStorage._new_shared_cuda(*handle)
                    reconstructed_tensor = torch.empty(
                        0, dtype=dtype, device=target_device
                    ).set_(storage, storage_offset=0, size=shape, stride=stride)
                self.set_(reconstructed_tensor)
            except Exception as e:
                print(f"Error: Failed to deserialize from CUDA IPC handle ({e}).")
                raise

        elif state["tensor_data"] is not None:
            self.set_(state["tensor_data"])
        else:
            raise pickle.UnpicklingError(
                "Invalid state for TransportProxyTensor: no tensor data found."
            )

    @property
    def name(self) -> Optional[str]:
        return self._metadata.get("name")

    @property
    def fields(self) -> Dict[str, Any]:
        return self._metadata.get("fields", {})

    @property
    def transport_mode(self) -> TensorTransportMode:
        return self._metadata.get("transport_mode", "default")


class MultiModalityDataPaddingPattern:
"""
Expand Down
17 changes: 3 additions & 14 deletions python/sglang/srt/managers/multimodal_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,18 +12,6 @@
PROCESSOR_MAPPING = {}


class DummyMultimodalProcessor(BaseMultimodalProcessor):
def __init__(self):
pass

async def process_mm_data_async(self, *args, **kwargs):
return None


def get_dummy_processor():
return DummyMultimodalProcessor()


def import_processors():
package_name = "sglang.srt.multimodal.processors"
package = importlib.import_module(package_name)
Expand All @@ -49,11 +37,12 @@ def import_processors():


def get_mm_processor(
hf_config, server_args: ServerArgs, processor
hf_config, server_args: ServerArgs, processor, transport_mode
) -> BaseMultimodalProcessor:
for model_cls, processor_cls in PROCESSOR_MAPPING.items():
if model_cls.__name__ in hf_config.architectures:
return processor_cls(hf_config, server_args, processor)
return processor_cls(hf_config, server_args, processor, transport_mode)

raise ValueError(
f"No processor registered for architecture: {hf_config.architectures}.\n"
f"Registered architectures: {[model_cls.__name__ for model_cls in PROCESSOR_MAPPING.keys()]}"
Expand Down
5 changes: 3 additions & 2 deletions python/sglang/srt/managers/schedule_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,10 +209,11 @@ class MultimodalDataItem:
hash: int = None
pad_value: int = None
offsets: Optional[list] = None

# the raw features returned by processor, e.g. pixel_values or audio_features
feature: Union[torch.Tensor, np.ndarray] = None

# the precomputed embeddings for the modality, e.g. image_emb for image, audio_emb for audio
# the precomputed embeddings, passed as final encoder embeddings
# One and only one of the feature and precomputed_embeddings will be empty
precomputed_embeddings: Optional[Union[torch.Tensor, np.ndarray]] = None

# Model-specific data stored in a dictionary
Expand Down
14 changes: 13 additions & 1 deletion python/sglang/srt/managers/tokenizer_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,7 @@
UpdateWeightsFromTensorReqInput,
UpdateWeightsFromTensorReqOutput,
)
from sglang.srt.managers.mm_utils import TensorTransportMode
from sglang.srt.managers.multimodal_processor import get_mm_processor, import_processors
from sglang.srt.metrics.collector import TokenizerMetricsCollector
from sglang.srt.sampling.sampling_params import SamplingParams
Expand Down Expand Up @@ -166,6 +167,16 @@ class ReqState:
output_token_ids_logprobs_idx: List = dataclasses.field(default_factory=list)


def _determine_tensor_transport_mode(server_args: ServerArgs) -> TensorTransportMode:
is_cross_node = server_args.dist_init_addr

if is_cross_node:
# Fallback to default CPU transport for multi-node
return "default"
else:
return "cuda_ipc"


class TokenizerManager:
"""TokenizerManager is a process that tokenizes the text."""

Expand Down Expand Up @@ -216,12 +227,13 @@ def __init__(
revision=server_args.revision,
use_fast=not server_args.disable_fast_image_processor,
)
transport_mode = _determine_tensor_transport_mode(self.server_args)

# We want to parallelize the image pre-processing so we create an executor for it
# We create mm_processor for any skip_tokenizer_init to make sure we still encode
# images even with skip_tokenizer_init=False.
self.mm_processor = get_mm_processor(
self.model_config.hf_config, server_args, _processor
self.model_config.hf_config, server_args, _processor, transport_mode
)

if server_args.skip_tokenizer_init:
Expand Down
26 changes: 20 additions & 6 deletions python/sglang/srt/multimodal/processors/base_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from PIL import Image
from transformers import BaseImageProcessorFast

from sglang.srt.managers.mm_utils import TransportProxyTensor
from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem
from sglang.srt.utils import load_audio, load_image, load_video, logger

Expand Down Expand Up @@ -142,11 +143,14 @@ def get_combined_regex(self) -> re.Pattern:
class BaseMultimodalProcessor(ABC):
models = []

def __init__(self, hf_config, server_args, _processor):
def __init__(
self, hf_config, server_args, _processor, transport_mode, *args, **kwargs
):
self.hf_config = hf_config
self._processor = _processor
self.arch = hf_config.architectures[0]
self.server_args = server_args
self.transport_mode = transport_mode

# FIXME: not accurate, model and image specific
self.NUM_TOKEN_PER_FRAME = 330
Expand Down Expand Up @@ -217,10 +221,6 @@ def process_mm_data(
return_tensors="pt",
**kwargs,
)
if "pixel_values" in result and isinstance(
result["pixel_values"], torch.Tensor
):
result["pixel_values"] = result["pixel_values"].to("cpu")
return result

@abstractmethod
Expand Down Expand Up @@ -500,7 +500,6 @@ def collect_mm_items_from_processor_output(
) -> List[MultimodalDataItem]:
"""Create mm_items directly from processor output."""
items: dict[Modality, MultimodalDataItem] = {}

for attr_name, value in data_dict.items():
if attr_name == "input_ids":
continue
Expand Down Expand Up @@ -624,4 +623,19 @@ def process_and_combine_mm_data(
mm_token_id=mm_token_id,
)

# post-process
for item in all_collected_items:
# replace the feature tensor with a proxy
if isinstance(item.feature, torch.Tensor) and item.feature.is_cuda:
item.feature = TransportProxyTensor(
transport_mode=self.transport_mode, data=item.feature
)
elif (
isinstance(item.precomputed_embeddings, torch.Tensor)
and item.feature.is_cuda
):
item.precomputed_embeddings = TransportProxyTensor(
transport_mode=self.transport_mode, data=item.precomputed_embeddings
)

return all_collected_items, input_ids, ret
4 changes: 2 additions & 2 deletions python/sglang/srt/multimodal/processors/clip.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@
class ClipImageProcessor(BaseMultimodalProcessor):
models = [CLIPModel]

def __init__(self, hf_config, server_args, _processor):
super().__init__(hf_config, server_args, _processor)
def __init__(self, hf_config, server_args, _processor, *args, **kwargs):
super().__init__(hf_config, server_args, _processor, *args, **kwargs)
self.mm_tokens = MultimodalSpecialTokens(image_token="<image>").build(
_processor
)
Expand Down
4 changes: 2 additions & 2 deletions python/sglang/srt/multimodal/processors/deepseek_vl_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,8 @@
class DeepseekVL2ImageProcessor(BaseMultimodalProcessor):
models = [DeepseekVL2ForCausalLM]

def __init__(self, hf_config, server_args, _processor):
super().__init__(hf_config, server_args, _processor)
def __init__(self, hf_config, server_args, _processor, *args, **kwargs):
super().__init__(hf_config, server_args, _processor, *args, **kwargs)
self.mm_tokens = MultimodalSpecialTokens(
image_token="<image>", image_token_id=self._processor.image_token_id
).build(_processor)
Expand Down
4 changes: 2 additions & 2 deletions python/sglang/srt/multimodal/processors/gemma3.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@
class Gemma3SGLangImageProcessor(SGLangBaseProcessor):
models = [Gemma3ForConditionalGeneration]

def __init__(self, hf_config, server_args, _processor):
super().__init__(hf_config, server_args, _processor)
def __init__(self, hf_config, server_args, _processor, *args, **kwargs):
super().__init__(hf_config, server_args, _processor, *args, **kwargs)
self.IM_START_TOKEN_ID = hf_config.boi_token_index
self.IM_END_TOKEN_ID = hf_config.eoi_token_index
self.mm_tokens = MultimodalSpecialTokens(
Expand Down
4 changes: 2 additions & 2 deletions python/sglang/srt/multimodal/processors/gemma3n.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,8 @@ class Gemma3nSGLangProcessor(SGLangBaseProcessor):

models = [Gemma3nForConditionalGeneration]

def __init__(self, hf_config, server_args, _processor):
super().__init__(hf_config, server_args, _processor)
def __init__(self, hf_config, server_args, _processor, *args, **kwargs):
super().__init__(hf_config, server_args, _processor, *args, **kwargs)

self.IM_START_TOKEN_ID = hf_config.boi_token_id
self.IM_END_TOKEN_ID = hf_config.eoi_token_id
Expand Down
4 changes: 2 additions & 2 deletions python/sglang/srt/multimodal/processors/internvl.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@
class InternVLImageProcessor(BaseMultimodalProcessor):
models = [InternVLChatModel]

def __init__(self, hf_config, server_args, _image_processor):
super().__init__(hf_config, server_args, _image_processor)
def __init__(self, hf_config, server_args, _image_processor, *args, **kwargs):
super().__init__(hf_config, server_args, _image_processor, *args, **kwargs)
image_size = hf_config.force_image_size or hf_config.vision_config.image_size
patch_size = hf_config.vision_config.patch_size

Expand Down
4 changes: 2 additions & 2 deletions python/sglang/srt/multimodal/processors/janus_pro.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@
class JanusProImageProcessor(BaseMultimodalProcessor):
models = [MultiModalityCausalLM]

def __init__(self, hf_config, server_args, _processor):
super().__init__(hf_config, server_args, _processor)
def __init__(self, hf_config, server_args, _processor, *args, **kwargs):
super().__init__(hf_config, server_args, _processor, *args, **kwargs)

self.mm_tokens = MultimodalSpecialTokens(
image_token=_processor.image_token,
Expand Down
4 changes: 2 additions & 2 deletions python/sglang/srt/multimodal/processors/kimi_vl.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@
class KimiVLImageProcessor(SGLangBaseProcessor):
models = [KimiVLForConditionalGeneration]

def __init__(self, hf_config, server_args, _processor):
super().__init__(hf_config, server_args, _processor)
def __init__(self, hf_config, server_args, _processor, *args, **kwargs):
super().__init__(hf_config, server_args, _processor, *args, **kwargs)
self.mm_tokens = MultimodalSpecialTokens(
image_token="<|media_pad|>",
# TODO: could we convert in MultimodalSpecialTokens?
Expand Down
8 changes: 4 additions & 4 deletions python/sglang/srt/multimodal/processors/llava.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,8 @@ class LlavaImageProcessor(BaseMultimodalProcessor):
LlavaMistralForCausalLM,
]

def __init__(self, hf_config, server_args, _processor):
super().__init__(hf_config, server_args, _processor)
def __init__(self, hf_config, server_args, _processor, *args, **kwargs):
super().__init__(hf_config, server_args, _processor, *args, **kwargs)

@staticmethod
def _process_single_image_task(
Expand Down Expand Up @@ -187,7 +187,7 @@ def _get_sgl_processor_cls(self, model_type: str):
f"Cannot find corresponding multimodal processor registered in sglang for model type `{model_type}`"
)

def __init__(self, hf_config, server_args, _processor):
def __init__(self, hf_config, server_args, _processor, *args, **kwargs):
assert hasattr(hf_config, "vision_config")
assert hasattr(hf_config, "text_config")
self.vision_config = hf_config.vision_config
Expand All @@ -196,7 +196,7 @@ def __init__(self, hf_config, server_args, _processor):

if vision_type := getattr(self.vision_config, "model_type"):
self.inner = self._get_sgl_processor_cls(vision_type)(
hf_config, server_args, _processor
hf_config, server_args, _processor, *args, **kwargs
)
else:
raise ValueError(
Expand Down
Loading
Loading