Skip to content
Merged
Show file tree
Hide file tree
Changes from 11 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
75 changes: 21 additions & 54 deletions python/sglang/srt/managers/mm_utils.py
Original file line number Diff line number Diff line change
def pad_input_tokens(
    self, input_ids: List[int], mm_inputs: MultimodalInputs
) -> List[int]:
    """
    Replace multimodal placeholder tokens in ``input_ids`` with the
    ``pad_value`` of the corresponding item in ``mm_inputs.mm_items``.

    Each modality (image, audio, video) contributes its own placeholder
    token id (``im_token_id`` / ``audio_token_id`` / ``video_token_id``);
    every occurrence of that id is rewritten to the item's pad value.

    Args:
        input_ids: Prompt token ids, possibly containing multimodal
            placeholder tokens.
        mm_inputs: Multimodal inputs holding the items and the
            per-modality placeholder token ids.

    Returns:
        A list of token ids with placeholders replaced. The original
        list is returned unchanged when there is nothing to pad.

    Raises:
        ValueError: If an item's modality has no token id registered
            on ``mm_inputs``.
    """
    # Nothing to pad: empty prompt or no multimodal items.
    if not input_ids or not mm_inputs.mm_items:
        return input_ids

    # Map each modality's placeholder token id to its pad value.
    token_to_pad_mapping = {}
    for item in mm_inputs.mm_items:
        if item.is_image() and mm_inputs.im_token_id is not None:
            token_to_pad_mapping[mm_inputs.im_token_id] = item.pad_value
        elif item.is_audio() and mm_inputs.audio_token_id is not None:
            token_to_pad_mapping[mm_inputs.audio_token_id] = item.pad_value
        elif item.is_video() and mm_inputs.video_token_id is not None:
            token_to_pad_mapping[mm_inputs.video_token_id] = item.pad_value
        else:
            raise ValueError(f"No multimodal token id provided for {item.modality}")

    # Single pass over the ids. Unlike sequential in-place replacement
    # (one masked write per token id), this cannot "chain" substitutions
    # when one item's pad_value happens to equal another modality's
    # placeholder token id, and it avoids a list -> tensor -> list
    # round trip on the CPU.
    return [token_to_pad_mapping.get(token_id, token_id) for token_id in input_ids]


embedding_cache = None
Expand Down
9 changes: 9 additions & 0 deletions python/sglang/srt/managers/schedule_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,15 @@ class Modality(Enum):
VIDEO = auto()
AUDIO = auto()

@staticmethod
def from_str(modality_str: str):
    """Parse a case-insensitive modality name (e.g. "image") into a Modality.

    Args:
        modality_str: Name of a ``Modality`` member, in any case.

    Returns:
        The matching ``Modality`` member.

    Raises:
        ValueError: If ``modality_str`` does not name a Modality member.
    """
    try:
        return Modality[modality_str.upper()]
    except KeyError:
        # ``from None`` suppresses the uninformative KeyError context so
        # callers see only the ValueError with the valid-name listing.
        raise ValueError(
            f"Invalid modality string: {modality_str}. Valid modalities are: {[m.name for m in Modality]}"
        ) from None


@dataclasses.dataclass
class MultimodalDataItem:
Expand Down
25 changes: 15 additions & 10 deletions python/sglang/srt/managers/tokenizer_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -484,20 +484,25 @@ async def _tokenize_one_request(
token_type_ids = encoded.get("token_type_ids", [None])[0]

if self.mm_processor and obj.contains_mm_input():
image_inputs = await self.mm_processor.process_mm_data_async(
if not isinstance(obj.image_data, list):
obj.image_data = [obj.image_data]
if not isinstance(obj.audio_data, list):
obj.audio_data = [obj.audio_data]
mm_inputs = await self.mm_processor.process_mm_data_async(
image_data=obj.image_data,
audio_data=obj.audio_data,
input_text=input_text or input_ids,
request_obj=obj,
max_req_input_len=self.max_req_input_len,
)
if image_inputs and "input_ids" in image_inputs:
input_ids = image_inputs["input_ids"]
if mm_inputs and "input_ids" in mm_inputs:
input_ids = mm_inputs["input_ids"]
else:
image_inputs: Optional[Dict] = None
mm_inputs = None

self._validate_one_request(obj, input_ids)
return self._create_tokenized_object(
obj, input_text, input_ids, input_embeds, image_inputs, token_type_ids
obj, input_text, input_ids, input_embeds, mm_inputs, token_type_ids
)

def _validate_one_request(
Expand Down Expand Up @@ -553,7 +558,7 @@ def _create_tokenized_object(
input_text: str,
input_ids: List[int],
input_embeds: Optional[Union[List[float], None]] = None,
image_inputs: Optional[Dict] = None,
mm_inputs: Optional[Dict] = None,
token_type_ids: Optional[List[int]] = None,
) -> Union[TokenizedGenerateReqInput, TokenizedEmbeddingReqInput]:
"""Create a tokenized request object from common parameters."""
Expand All @@ -578,7 +583,7 @@ def _create_tokenized_object(
obj.rid,
input_text,
input_ids,
image_inputs,
mm_inputs,
sampling_params,
obj.return_logprob,
obj.logprob_start_len,
Expand All @@ -600,7 +605,7 @@ def _create_tokenized_object(
obj.rid,
input_text,
input_ids,
image_inputs,
mm_inputs,
token_type_ids,
sampling_params,
)
Expand Down Expand Up @@ -638,9 +643,9 @@ def _validate_batch_tokenization_constraints(
) -> None:
"""Validate constraints for batch tokenization processing."""
for i in range(batch_size):
if self.is_generation and obj[i].image_data:
if self.is_generation and obj[i].contains_mm_input():
raise ValueError(
"For image input processing do not set `enable_tokenizer_batch_encode`."
"For multimodal input processing do not set `enable_tokenizer_batch_encode`."
)
if obj[i].input_ids is not None:
raise ValueError(
Expand Down
19 changes: 3 additions & 16 deletions python/sglang/srt/models/gemma3n_mm.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
from sglang.srt.managers.mm_utils import (
MultiModalityDataPaddingPatternTokenPairs,
MultiModalityDataPaddingPatternMultimodalTokens,
general_mm_embed_routine,
)
from sglang.srt.managers.schedule_batch import (
Expand Down Expand Up @@ -250,20 +250,8 @@ def pad_input_ids(
if mm_inputs is None:
return input_ids

# Collect available media token pairs
media_token_pairs = []
for attr_name in ["im_start_id", "audio_start_id"]:
if hasattr(mm_inputs, attr_name):
start_id = getattr(mm_inputs, attr_name)
end_id = getattr(mm_inputs, attr_name.replace("start", "end"))
media_token_pairs.append((start_id, end_id))

# Apply padding pattern if we have media tokens
if media_token_pairs:
pattern = MultiModalityDataPaddingPatternTokenPairs(media_token_pairs)
return pattern.pad_input_tokens(input_ids, mm_inputs)

return input_ids
pattern = MultiModalityDataPaddingPatternMultimodalTokens(None)
return pattern.pad_input_tokens(input_ids, mm_inputs)

def get_input_embeddings(self) -> nn.Embedding:
return self.language_model.get_input_embeddings()
Expand Down Expand Up @@ -431,7 +419,6 @@ def forward(
)

positions += 1

if input_ids is not None:
# Prepare per-layer inputs from inputs_ids
per_layer_inputs_mask = torch.logical_and(
Expand Down
Loading
Loading