62 changes: 29 additions & 33 deletions python/sglang/srt/managers/mm_utils.py
@@ -85,8 +85,8 @@ def pad_input_tokens(
             "No data_token_pairs provided, RadixAttention might be influenced."
         )
         return input_ids
-    start_token_ids = [s for s, _e in data_token_pairs]
-    end_tokens_ids = [e for _s, e in data_token_pairs]
+    start_token_ids = {s for s, _e in data_token_pairs}
+    end_tokens_ids = {e for _s, e in data_token_pairs}
 
     padded_ids = []
     last_idx = 0
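
Reviewer note on the hunk above: `pad_input_tokens` tests every input token for membership in `start_token_ids` and `end_tokens_ids` inside its scan loop, so the list-to-set change turns an O(n) lookup into O(1). A minimal standalone sketch of the difference, using hypothetical token ids rather than anything from this file:

    import timeit

    # hypothetical (start, end) special-token pairs
    data_token_pairs = [(100 + i, 200 + i) for i in range(64)]
    start_list = [s for s, _e in data_token_pairs]
    start_set = {s for s, _e in data_token_pairs}

    token = 999  # worst case for the list: token is absent
    print(timeit.timeit(lambda: token in start_list, number=100_000))
    print(timeit.timeit(lambda: token in start_set, number=100_000))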
@@ -135,7 +135,7 @@ def pad_input_tokens(
     if not input_ids or not mm_inputs.mm_items:
         return input_ids
 
-    input_ids_tensor = torch.tensor(input_ids)
+    input_ids_tensor = torch.as_tensor(input_ids)
 
     # Create mapping of token_ids to pad_values for each modality
     token_to_pad_mapping = {}
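
Reviewer note on the `torch.tensor` to `torch.as_tensor` swaps in this PR: for a plain Python list both calls copy the data, but `torch.as_tensor` reuses the existing storage when handed something that is already a tensor or a NumPy array with a matching dtype and device, so the already-a-tensor path becomes free. A minimal sketch:

    import torch

    ids = torch.arange(8)
    copied = torch.tensor(ids)     # always copies (PyTorch also warns on tensor input)
    shared = torch.as_tensor(ids)  # reuses storage, no copy
    print(copied.data_ptr() == ids.data_ptr())  # False
    print(shared.data_ptr() == ids.data_ptr())  # True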
@@ -211,7 +211,7 @@ def get_embedding_chunk(
             end_index += extend_end_index - start + 1
         elif extend_end_index > end:
             end_index += end - start + 1
-    # some models embedding is 3-dim, reshape it to 2-dim
+    # some models' embedding is 3-dim, reshape it to 2-dim
     embedding = embedding.reshape(-1, embedding.shape[-1])
     embedding_chunk = embedding[start_index:end_index]
     return embedding_chunk, start_index, end_index
@@ -428,7 +428,7 @@ def embed_mm_inputs(
         modality_id = modality.name.lower()
         embedder = getattr(multimodal_model, f"get_{modality_id}_feature", None)
         if len(items) != 0 and embedder is not None:
-            placeholder_tensor = torch.tensor(
+            placeholder_tensor = torch.as_tensor(
                 [item.pad_value for item in items],
                 device=input_ids.device,
             )
@@ -473,11 +473,9 @@ def embed_mm_inputs(
     for embedding, mask in zip(embeddings, masks):
         if embedding is None or mask is None:
             continue
-        mask = mask.expand_as(inputs_embeds).to(inputs_embeds.device)
-        inputs_embeds = inputs_embeds.masked_scatter(
-            mask,
-            embedding.to(inputs_embeds.device, inputs_embeds.dtype),
-        )
+        # in-place update
+        indices = torch.where(mask.squeeze(dim=-1))[0]
+        inputs_embeds[indices] = embedding.to(inputs_embeds.device, inputs_embeds.dtype)
     return inputs_embeds
 
 
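Reviewer note on the hunk above: when the mask marks whole rows of `inputs_embeds`, scattering through a full-size expanded mask and assigning rows by index produce the same result, but the index assignment runs in place and never materializes the `[seq_len, hidden]` boolean mask. A minimal equivalence sketch with hypothetical shapes:

    import torch

    inputs_embeds = torch.zeros(6, 4)
    mask = torch.tensor([0, 1, 1, 0, 1, 0], dtype=torch.bool).unsqueeze(-1)  # [seq, 1]
    embedding = torch.ones(3, 4)  # one row per masked position

    # old path: out-of-place masked_scatter over the broadcast mask
    out_scatter = inputs_embeds.masked_scatter(mask.expand_as(inputs_embeds), embedding)

    # new path: in-place row assignment via the squeezed mask's indices
    indices = torch.where(mask.squeeze(dim=-1))[0]
    out_inplace = inputs_embeds.clone()
    out_inplace[indices] = embedding

    print(torch.equal(out_scatter, out_inplace))  # True
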
@@ -561,51 +559,53 @@ def get_multimodal_data_bounds(
         [bounds_count, 2]
     """
     # All the multimodal data in the batch should share the same special bound token ids.
-    start_tokens = [s for s, _e in token_pairs]
-    end_tokens = [e for _s, e in token_pairs]
+    start_tokens = {s for s, _e in token_pairs}
+    end_tokens = {e for _s, e in token_pairs}
 
     assert all(isinstance(t, int) for t in start_tokens)
     assert all(isinstance(t, int) for t in end_tokens)
 
     start_cond = torch.isin(
-        input_ids, torch.tensor(start_tokens, device=input_ids.device)
+        input_ids, torch.as_tensor(list(start_tokens), device=input_ids.device)
     )
-    end_cond = torch.isin(input_ids, torch.tensor(end_tokens, device=input_ids.device))
+    end_cond = torch.isin(
+        input_ids, torch.as_tensor(list(end_tokens), device=input_ids.device)
+    )
 
     (data_start_tokens,) = torch.where(start_cond)
     (data_end_tokens,) = torch.where(end_cond)
 
+    data_start_tokens_cpu = data_start_tokens.cpu().tolist()
+    data_end_tokens_cpu = data_end_tokens.cpu().tolist()
+
     # the im_start_id sometimes can be cached as prefix, but it is needed for the embedding of the multimodal data
-    if len(data_start_tokens) != len(data_end_tokens):
+    if len(data_start_tokens_cpu) != len(data_end_tokens_cpu):
         if (
-            len(data_start_tokens) + 1 == len(data_end_tokens)
-            and input_ids[0] in pad_values
-            and data_end_tokens[0] < data_start_tokens[0]
+            len(data_start_tokens_cpu) + 1 == len(data_end_tokens_cpu)
+            and input_ids[0].item() in pad_values
+            and data_end_tokens_cpu
+            and data_start_tokens_cpu
+            and data_end_tokens_cpu[0] < data_start_tokens_cpu[0]
         ):
-            data_start_tokens = torch.cat(
-                [
-                    torch.tensor([0], device=data_start_tokens.device),
-                    data_start_tokens,
-                ]
-            )
-    valid_mm_data_nums = min(len(data_start_tokens), len(data_end_tokens))
+            data_start_tokens_cpu.insert(0, 0)
+    valid_mm_data_nums = min(len(data_start_tokens_cpu), len(data_end_tokens_cpu))
 
     if valid_mm_data_nums == 0:
         return torch.zeros((0, 2), device=input_ids.device)
 
     # Filter out pairs where start_token >= end_token
     valid_pairs = []
     for i in range(valid_mm_data_nums):
-        start_token = data_start_tokens[i]
-        end_token = data_end_tokens[i]
+        start_token = data_start_tokens_cpu[i]
+        end_token = data_end_tokens_cpu[i]
         if start_token < end_token:
             valid_pairs.append((start_token + 1, end_token - 1))
 
     if not valid_pairs:
         return torch.zeros((0, 2), device=input_ids.device)
 
     # Convert valid pairs to tensor
-    valid_pairs_tensor = torch.tensor(valid_pairs, device=input_ids.device)
+    valid_pairs_tensor = torch.as_tensor(valid_pairs, device=input_ids.device)
     return valid_pairs_tensor
 
 
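Reviewer note on the hunk above: reading CUDA tensor elements one at a time (`data_start_tokens[i]`, `input_ids[0]`) forces a host-device synchronization per access; converting once with `.cpu().tolist()` pays for a single transfer, after which the comparison loop runs over plain Python ints. A minimal sketch of the pattern, with hypothetical helper names:

    import torch

    def bounds_per_element(starts: torch.Tensor, ends: torch.Tensor):
        # one implicit sync per starts[i] / ends[i] when the tensors live on a GPU
        return [(int(starts[i]) + 1, int(ends[i]) - 1) for i in range(len(starts))]

    def bounds_batched(starts: torch.Tensor, ends: torch.Tensor):
        starts_cpu = starts.cpu().tolist()  # single device-to-host transfer
        ends_cpu = ends.cpu().tolist()
        return [(s + 1, e - 1) for s, e in zip(starts_cpu, ends_cpu)]
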
@@ -634,11 +634,7 @@ def tensor_hash(tensor_list) -> int:
         tensor = tensor.float()
 
     assert isinstance(tensor, torch.Tensor)
-    if tensor.is_cuda:
-        # TODO: improve this
-        tensor_cpu = tensor.cpu()
-    else:
-        tensor_cpu = tensor
+    tensor_cpu = tensor.cpu()
 
     mv = memoryview(tensor_cpu.numpy())
     return data_hash(mv.tobytes())
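
Reviewer note on the hunk above: the `is_cuda` branch was redundant because `Tensor.cpu()` already returns the original object, with no copy, when the tensor lives in CPU memory, so the unconditional call is safe on both devices:

    import torch

    t = torch.arange(4)
    print(t.cpu() is t)  # True: no copy for a tensor already on the CPU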