Merged

Changes from 12 commits

Commits (60)
dfebf51
[Bugfix] Merge multimodal embeddings by `is_embed` mask instead of to…
DarkLight1337 Apr 8, 2025
437dacd
Rename
DarkLight1337 Apr 8, 2025
bbe7096
Merge branch 'main' into rm-merge-mm-embeddings
DarkLight1337 Apr 9, 2025
57e9f03
Use #16007
DarkLight1337 Apr 9, 2025
d5c9555
Merge branch 'main' into rm-merge-mm-embeddings
DarkLight1337 Aug 27, 2025
e08deaa
Fix
DarkLight1337 Aug 27, 2025
302b2c5
Merge branch 'main' into rm-merge-mm-embeddings
DarkLight1337 Aug 27, 2025
6a1307f
Update
DarkLight1337 Aug 28, 2025
3a4740a
Fix
DarkLight1337 Aug 28, 2025
68c54d8
Draft
DarkLight1337 Aug 28, 2025
6ddc91e
Fix device
DarkLight1337 Aug 28, 2025
28cc8cb
Persistent buffer
DarkLight1337 Aug 28, 2025
c335908
Avoid unnecessary initialization
DarkLight1337 Aug 28, 2025
cbb70ea
Fix reset
DarkLight1337 Aug 28, 2025
76f2925
Update
DarkLight1337 Aug 28, 2025
b6e8775
Simplify
DarkLight1337 Aug 28, 2025
fee0d27
Merge branch 'main' into rm-merge-mm-embeddings
DarkLight1337 Aug 30, 2025
f71a40b
Use padded tokens
DarkLight1337 Sep 1, 2025
3af1bdb
Fix wrong device
DarkLight1337 Sep 1, 2025
003800e
Merge branch 'main' into rm-merge-mm-embeddings
DarkLight1337 Sep 1, 2025
975569d
Debug
DarkLight1337 Sep 2, 2025
8d6b6c4
Merge branch 'main' into rm-merge-mm-embeddings
DarkLight1337 Sep 18, 2025
c001581
Fix?
DarkLight1337 Sep 18, 2025
9e4512c
Simplify the code
DarkLight1337 Sep 18, 2025
e002d44
Reduce diffs
DarkLight1337 Sep 18, 2025
1934f25
Avoid intermediate variable
DarkLight1337 Sep 18, 2025
573cb4b
Standardize input embeddings logic
DarkLight1337 Sep 18, 2025
fa5e688
Cleanup
DarkLight1337 Sep 18, 2025
0799fdb
Fix
DarkLight1337 Sep 18, 2025
7f58edc
Fix
DarkLight1337 Sep 18, 2025
1e9ec64
Comment out debug path
DarkLight1337 Sep 18, 2025
439b264
Merge branch 'main' into rm-merge-mm-embeddings
DarkLight1337 Sep 19, 2025
a9f7e84
fix tpu recompilations
NickLucche Sep 19, 2025
29e0ad5
Remove sanity check for code simplicity
DarkLight1337 Sep 19, 2025
9a6768e
Merge branch 'main' into rm-merge-mm-embeddings
DarkLight1337 Sep 19, 2025
f6e7e62
Update interface for all MM models
DarkLight1337 Sep 19, 2025
74a4d5f
Avoid circular import
DarkLight1337 Sep 19, 2025
6d3a733
Fix `get_input_embeddings`
DarkLight1337 Sep 20, 2025
d30a4a6
Improve logging for unimpl methods
DarkLight1337 Sep 20, 2025
ad27e91
Merge branch 'main' into rm-merge-mm-embeddings
DarkLight1337 Sep 20, 2025
028aedf
More fixes
DarkLight1337 Sep 20, 2025
38058d1
Fix
DarkLight1337 Sep 20, 2025
a71a832
Fix
DarkLight1337 Sep 20, 2025
3d4495a
Merge branch 'main' into rm-merge-mm-embeddings
DarkLight1337 Sep 20, 2025
7d8f58d
Fix V0
DarkLight1337 Sep 20, 2025
e33a195
Rename `do_language_embed_multimodal -> handle_oov_mm_token`
DarkLight1337 Sep 21, 2025
ead536d
Update docstring
DarkLight1337 Sep 21, 2025
6db35c3
Add doc
DarkLight1337 Sep 21, 2025
7dc2675
Merge branch 'main' into rm-merge-mm-embeddings
DarkLight1337 Sep 22, 2025
d13fca8
Update DotsOCR
DarkLight1337 Sep 22, 2025
beb9df0
Fix wrong condition
DarkLight1337 Sep 22, 2025
8a6fb1b
fix qwen3-vl
ywang96 Sep 22, 2025
2eefc2d
Fix wrong condition
DarkLight1337 Sep 22, 2025
b79860e
Merge branch 'main' into rm-merge-mm-embeddings
DarkLight1337 Sep 22, 2025
7769ec1
Merge branch 'main' into rm-merge-mm-embeddings
DarkLight1337 Sep 24, 2025
aa67033
Reduce diff
DarkLight1337 Sep 24, 2025
3656239
Merge branch 'main' into rm-merge-mm-embeddings
DarkLight1337 Sep 26, 2025
9260170
Simplify
DarkLight1337 Sep 26, 2025
2ac91b6
Fix doc
DarkLight1337 Sep 26, 2025
3033297
Merge branch 'main' into rm-merge-mm-embeddings
DarkLight1337 Sep 27, 2025
22 changes: 14 additions & 8 deletions vllm/v1/spec_decode/eagle.py
@@ -18,6 +18,7 @@
from vllm.model_executor.model_loader import get_model
from vllm.model_executor.models import supports_multimodal
from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM
from vllm.model_executor.models.utils import _merge_multimodal_embeddings
from vllm.platforms import current_platform
from vllm.utils import is_pin_memory_available
from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata
@@ -157,6 +158,7 @@ def propose(
next_token_ids: torch.Tensor,
common_attn_metadata: CommonAttentionMetadata,
sampling_metadata: SamplingMetadata,
is_mm_embed: Optional[torch.Tensor] = None,
mm_embeds: Optional[list[torch.Tensor]] = None,
) -> torch.Tensor:
num_tokens = target_token_ids.shape[0]
@@ -196,18 +198,22 @@
# copy inputs to buffer for cudagraph
self.positions[:num_tokens] = target_positions
self.hidden_states[:num_tokens] = target_hidden_states
if self.is_multimodal_model:
input_ids = self.input_ids[:num_tokens]
inputs_embeds = self.model.get_input_embeddings(
input_ids,
multimodal_embeddings=mm_embeds or None,

if mm_embeds:
DarkLight1337 (Member, Author) commented on this block: This makes the code format consistent with the model runner.

assert is_mm_embed is not None

inputs_embeds_scheduled = _merge_multimodal_embeddings(
self.input_ids[:num_tokens],
is_mm_embed,
multimodal_embeddings=mm_embeds,
)
self.inputs_embeds[:num_tokens] = inputs_embeds
inputs_embeds = self.inputs_embeds[:num_input_tokens]
self.inputs_embeds[:num_tokens] = inputs_embeds_scheduled

input_ids = None
inputs_embeds = self.inputs_embeds[:num_input_tokens]
else:
inputs_embeds = None
input_ids = self.input_ids[:num_input_tokens]
inputs_embeds = None

with set_forward_context(per_layer_attn_metadata,
self.vllm_config,
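For context on the helper used above: `_merge_multimodal_embeddings` overwrites the positions selected by a boolean `is_mm_embed` mask with the flattened encoder outputs, leaving the remaining (text) positions untouched. Below is a minimal sketch of that mask-based merge, operating on an embedding tensor as in the model-runner hunks further down; it is an illustration under assumed names and shapes, not the actual helper in `vllm/model_executor/models/utils.py`:

```python
import torch


def merge_mm_embeddings_sketch(
    inputs_embeds: torch.Tensor,    # [num_tokens, hidden_size]
    is_mm_embed: torch.Tensor,      # [num_tokens], bool
    mm_embeds: list[torch.Tensor],  # per-item outputs, each [n_i, hidden_size]
) -> torch.Tensor:
    # Flatten the per-item encoder outputs into one [sum(n_i), hidden_size] tensor.
    flattened = torch.cat(mm_embeds, dim=0).to(inputs_embeds.dtype)
    # The number of True positions must match the number of encoder tokens.
    assert int(is_mm_embed.sum()) == flattened.shape[0]
    # Overwrite the masked slots; text positions keep their token embeddings.
    merged = inputs_embeds.clone()
    merged[is_mm_embed] = flattened
    return merged
```

Because the model runner builds the same mask (see `_gather_mm_embeddings` below), the drafter and the runner share one merge path, which is what the review comment above refers to.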
52 changes: 41 additions & 11 deletions vllm/v1/worker/gpu_model_runner.py
@@ -44,6 +44,7 @@
supports_transcription)
from vllm.model_executor.models.interfaces_base import (
VllmModelForPooling, is_pooling_model, is_text_generation_model)
from vllm.model_executor.models.utils import _merge_multimodal_embeddings
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (BatchedTensorInputs, MultiModalKwargsItem,
PlaceholderRange)
@@ -257,6 +258,10 @@ def __init__(
dtype=self.dtype,
device=self.device)

# Only relevant for multimodal models
self.is_mm_embed = self._make_buffer(self.max_num_tokens,
dtype=torch.bool)

# Only relevant for models using M-RoPE (e.g, Qwen2-VL)
if self.uses_mrope:
# NOTE: `mrope_positions` is implemented with one additional dummy
@@ -1185,8 +1190,11 @@ def _gather_mm_embeddings(
self,
scheduler_output: "SchedulerOutput",
shift_computed_tokens: int = 0,
) -> list[torch.Tensor]:
mm_embeds: list[torch.Tensor] = []
) -> tuple[torch.Tensor, list[torch.Tensor]]:
is_mm_embed = self.is_mm_embed.cpu
mm_embeds = list[torch.Tensor]()

req_start_idx = 0
for req_id in self.input_batch.req_ids:
num_scheduled_tokens = scheduler_output.num_scheduled_tokens[
req_id]
@@ -1195,6 +1203,7 @@
req_state.num_computed_tokens + shift_computed_tokens
mm_positions = req_state.mm_positions
mm_hashes = req_state.mm_hashes

for i, pos_info in enumerate(mm_positions):
start_pos = pos_info.offset
num_encoder_tokens = pos_info.length
@@ -1211,6 +1220,10 @@
# in the decoder's KV cache.
continue

req_start_pos = req_start_idx + start_pos
is_mm_embed[req_start_pos:req_start_pos + num_encoder_tokens] \
= True if pos_info.is_embed is None else pos_info.is_embed

start_idx = max(num_computed_tokens - start_pos, 0)
end_idx = min(
num_computed_tokens - start_pos + num_scheduled_tokens,
@@ -1231,7 +1244,13 @@
is_embed=is_embed,
)
mm_embeds.append(mm_embeds_item)
return mm_embeds

req_start_idx += num_scheduled_tokens

total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
self.is_mm_embed.copy_to_gpu(total_num_scheduled_tokens)

return self.is_mm_embed.gpu[:total_num_scheduled_tokens], mm_embeds

def get_model(self) -> nn.Module:
# get raw model out of the cudagraph wrapper.
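A rough picture of what the reworked `_gather_mm_embeddings` does with the new `is_mm_embed` buffer: for every scheduled request it marks the placeholder ranges at batch-relative offsets in a CPU-side boolean buffer, then copies only the prefix covering the scheduled tokens to the GPU. The sketch below is a simplification under assumed inputs; it ignores partial scheduling, `pos_info.is_embed` sub-masks, and the persistent CPU/GPU buffer helper used by the runner:

```python
import torch


def build_is_mm_embed_sketch(
    num_scheduled_tokens: dict[str, int],                  # req_id -> tokens scheduled this step
    placeholder_ranges: dict[str, list[tuple[int, int]]],  # req_id -> (offset, length) pairs
    max_num_tokens: int,
) -> torch.Tensor:
    mask = torch.zeros(max_num_tokens, dtype=torch.bool)
    req_start_idx = 0
    for req_id, num_tokens in num_scheduled_tokens.items():
        for offset, length in placeholder_ranges.get(req_id, []):
            start = req_start_idx + offset
            mask[start:start + length] = True
        req_start_idx += num_tokens
    # Only the prefix covering the scheduled tokens is copied to the device.
    return mask[:req_start_idx]
```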
@@ -1514,18 +1533,24 @@ def execute_model(
if self.supports_mm_inputs:
# Run the multimodal encoder if any.
self._execute_mm_encoder(scheduler_output)
mm_embeds = self._gather_mm_embeddings(scheduler_output)
is_mm_embed, mm_embeds = self._gather_mm_embeddings(
scheduler_output)
else:
mm_embeds = []
is_mm_embed, mm_embeds = torch.tensor(False), []

if self.supports_mm_inputs and get_pp_group().is_first_rank:
# NOTE(woosuk): To unify token ids and soft tokens (vision
# embeddings), we always use embeddings (rather than token ids)
# as input to the multimodal model, even when the input is text.
inputs_embeds_scheduled = self.model.get_input_embeddings(
input_ids=self.input_ids.gpu[:num_scheduled_tokens],
multimodal_embeddings=mm_embeds or None,
)
self.input_ids.gpu[:num_scheduled_tokens])

if mm_embeds:
inputs_embeds_scheduled = _merge_multimodal_embeddings(
inputs_embeds_scheduled,
is_mm_embed,
multimodal_embeddings=mm_embeds,
)

# TODO(woosuk): Avoid the copy. Optimize.
self.inputs_embeds[:num_scheduled_tokens].copy_(
@@ -1860,10 +1885,14 @@ def propose_draft_token_ids(
[h[token_indices] for h in aux_hidden_states], dim=-1)
else:
target_hidden_states = hidden_states[token_indices]
mm_embeds = None

if self.supports_mm_inputs:
mm_embeds = self._gather_mm_embeddings(scheduler_output,
shift_computed_tokens=1)
is_mm_embed, mm_embeds = self._gather_mm_embeddings(
scheduler_output,
shift_computed_tokens=1,
)
else:
is_mm_embed, mm_embeds = torch.tensor(False), []

draft_token_ids = self.drafter.propose(
target_token_ids=target_token_ids,
Expand All @@ -1872,6 +1901,7 @@ def propose_draft_token_ids(
next_token_ids=next_token_ids,
sampling_metadata=sampling_metadata,
common_attn_metadata=common_attn_metadata,
is_mm_embed=is_mm_embed,
mm_embeds=mm_embeds,
)
return draft_token_ids
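Putting the `execute_model` hunk above together with the NOTE(woosuk) comment: the runner always embeds the token ids first and only then overwrites the masked positions, rather than handing `multimodal_embeddings` to `get_input_embeddings`. Text-only steps pass `torch.tensor(False)` and an empty list purely as placeholders, so the merge branch is skipped. A condensed, assumed version of that flow (not the actual runner code; `merge_mm_embeddings_sketch` is the illustrative helper defined earlier):

```python
import torch


def compute_inputs_embeds_sketch(
    model,                          # any module exposing get_input_embeddings(input_ids)
    input_ids: torch.Tensor,        # [num_scheduled_tokens]
    is_mm_embed: torch.Tensor,      # bool mask, or torch.tensor(False) when text-only
    mm_embeds: list[torch.Tensor],  # [] when text-only
) -> torch.Tensor:
    # Always start from token-id embeddings, even for multimodal prompts.
    inputs_embeds = model.get_input_embeddings(input_ids)
    # An empty list skips the merge, so the scalar-False placeholder mask
    # is never actually indexed.
    if mm_embeds:
        inputs_embeds = merge_mm_embeddings_sketch(inputs_embeds, is_mm_embed,
                                                   mm_embeds)
    return inputs_embeds
```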
84 changes: 46 additions & 38 deletions vllm/v1/worker/tpu_model_runner.py
@@ -31,6 +31,7 @@
from vllm.model_executor.models.interfaces import supports_transcription
from vllm.model_executor.models.interfaces_base import (
is_pooling_model, is_text_generation_model)
from vllm.model_executor.models.utils import _merge_multimodal_embeddings
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (BatchedTensorInputs, MultiModalKwargsItem,
PlaceholderRange)
@@ -261,6 +262,12 @@ def __init__(
pin_memory=self.pin_memory)
self.seq_lens_np = self.seq_lens_cpu.numpy()

# Only relevant for multimodal models
self.is_mm_embed_cpu = torch.zeros(self.max_num_tokens,
dtype=torch.bool,
device="cpu",
pin_memory=self.pin_memory)

# Range tensor with values [0 .. self.max_num_tokens - 1].
# Used to initialize positions / context_lens / seq_lens
# Keep in int64 to avoid overflow with long context
@@ -809,31 +816,6 @@ def _prepare_inputs(self, scheduler_output: "SchedulerOutput",
return per_layer_attn_metadata, logits_indices, padded_num_reqs,\
num_reqs, end_index

def _scatter_placeholders(
self,
embeds: torch.Tensor,
is_embed: Optional[torch.Tensor],
) -> torch.Tensor:
if is_embed is None:
return embeds

placeholders = embeds.new_full(
(is_embed.shape[0], embeds.shape[-1]),
fill_value=torch.nan,
)
placeholders[is_embed] = embeds
return placeholders

def _gather_placeholders(
self,
placeholders: torch.Tensor,
is_embed: Optional[torch.Tensor],
) -> torch.Tensor:
if is_embed is None:
return placeholders

return placeholders[is_embed]

def _execute_mm_encoder(self, scheduler_output: "SchedulerOutput"):
scheduled_encoder_inputs = scheduler_output.scheduled_encoder_inputs
if not scheduled_encoder_inputs:
@@ -906,8 +888,11 @@ def _execute_mm_encoder(self, scheduler_output: "SchedulerOutput"):
def _gather_mm_embeddings(
self,
scheduler_output: "SchedulerOutput",
) -> list[torch.Tensor]:
mm_embeds: list[torch.Tensor] = []
) -> tuple[torch.Tensor, list[torch.Tensor]]:
is_mm_embed = self.is_mm_embed_cpu
mm_embeds = list[torch.Tensor]()

req_start_idx = 0
for req_id in self.input_batch.req_ids:
num_scheduled_tokens = scheduler_output.num_scheduled_tokens[
req_id]
@@ -935,6 +920,10 @@ def _gather_mm_embeddings(
# in the decoder's KV cache.
continue

req_start_pos = req_start_idx + start_pos
is_mm_embed[req_start_pos:req_start_pos + num_encoder_tokens] \
= True

mm_hash = mm_hashes[i]
encoder_output = self.encoder_cache.get(mm_hash, None)
assert encoder_output is not None,\
@@ -943,18 +932,33 @@
" be contiguous and embeddings."
encoder_output = self.encoder_cache[mm_hash]
mm_embeds.append(encoder_output)
return mm_embeds

def _get_model_inputs(self, input_ids: torch.Tensor,
mm_embeds: list[torch.Tensor]):
req_start_idx += num_scheduled_tokens

total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
is_mm_embed = is_mm_embed[:total_num_scheduled_tokens].to(self.device)

return is_mm_embed, mm_embeds

def _get_model_inputs(
self,
input_ids: torch.Tensor,
is_mm_embed: torch.Tensor,
mm_embeds: list[torch.Tensor],
):
if self.supports_mm_inputs:
# NOTE(woosuk): To unify token ids and soft tokens (vision
# embeddings), we always use embeddings (rather than token ids)
# as input to the multimodal model, even when the input is text.
inputs_embeds = self.model.get_input_embeddings(
input_ids=input_ids,
multimodal_embeddings=mm_embeds,
)
inputs_embeds = self.model.get_input_embeddings(input_ids)

if mm_embeds:
inputs_embeds = _merge_multimodal_embeddings(
inputs_embeds,
is_mm_embed,
multimodal_embeddings=mm_embeds,
)

return None, inputs_embeds
else:
# For text-only models, we use token ids as input.
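The TPU runner mirrors the GPU path: `_get_model_inputs` now takes the device-side mask alongside the embeddings list, which is also why the NaN-placeholder `_scatter_placeholders`/`_gather_placeholders` helpers removed above are no longer needed. A hypothetical call shape based on the diff (the `runner` variable and argument values are illustrative, not taken verbatim from the runner):

```python
# Hypothetical usage; `runner` stands in for a TPUModelRunner instance.
input_ids, inputs_embeds = runner._get_model_inputs(
    runner.input_ids,  # padded token-id buffer on the device
    is_mm_embed,       # bool mask sliced to the scheduled tokens, moved to device
    mm_embeds,         # cached encoder outputs for this step (may be empty)
)
# Multimodal models get input_ids=None and use inputs_embeds; text-only models
# get inputs_embeds=None and feed token ids directly.
```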
@@ -982,9 +986,11 @@ def execute_model(
if self.supports_mm_inputs:
# Run the multimodal encoder if any.
self._execute_mm_encoder(scheduler_output)
mm_embeds = self._gather_mm_embeddings(scheduler_output)
is_mm_embed, mm_embeds = self._gather_mm_embeddings(
scheduler_output)
else:
mm_embeds = []
is_mm_embed, mm_embeds = torch.tensor(False), []

xm.mark_step()
# Prepare inputs, the requests might be split into multiple
# executions, combine the result of each execution.
@@ -1001,7 +1007,7 @@
attn_metadata, logits_indices, padded_num_reqs, num_reqs,\
end_index = self._prepare_inputs(scheduler_output, start_index)
input_ids, inputs_embeds = self._get_model_inputs(
self.input_ids, mm_embeds)
self.input_ids, is_mm_embed, mm_embeds)
xm.mark_step()
# Run the decoder
with set_forward_context(
@@ -1358,6 +1364,7 @@ def _precompile_mm_encoder(self) -> None:
placeholders_ids = placeholders_ids.to(self.device)
# Assign outputs or the graph will be cut short.
a, b = self._get_model_inputs(placeholders_ids,
torch.tensor(True),
[mm_embeds])
assert a is None
xm.mark_step()
@@ -1369,7 +1376,8 @@
dtype=torch.int32,
device="cpu")
placeholders_ids = placeholders_ids.to(self.device)
a, b = self._get_model_inputs(placeholders_ids, [])
a, b = self._get_model_inputs(placeholders_ids,
torch.tensor(False), [])
assert a is None
xm.mark_step()
