
Commit 49fe3f7

feat: inputs_embeds for new models
1 parent 29525e1 commit 49fe3f7

File tree

3 files changed: +32 −14 lines


vllm/model_executor/models/exaone.py

Lines changed: 11 additions & 6 deletions
@@ -54,7 +54,8 @@
 from vllm.utils import is_hip
 
 from .interfaces import SupportsLoRA
-from .utils import PPMissingLayer, is_pp_missing_parameter, make_layers
+from .utils import (PPMissingLayer, get_inputs_embeds, is_pp_missing_parameter,
+                    make_layers)
 
 
 class ExaoneGatedMLP(nn.Module):
@@ -365,12 +366,13 @@ def forward(
         attn_metadata: AttentionMetadata,
         intermediate_tensors: Optional[IntermediateTensors],
         inputs_embeds: Optional[torch.Tensor] = None,
+        inputs_embeds_masks: Optional[torch.Tensor] = None,
     ) -> Union[torch.Tensor, IntermediateTensors]:
         if get_pp_group().is_first_rank:
-            if inputs_embeds is not None:
-                hidden_states = inputs_embeds
-            else:
-                hidden_states = self.get_input_embeddings(input_ids)
+            hidden_states = get_inputs_embeds(input_ids,
+                                              self.get_input_embeddings,
+                                              inputs_embeds,
+                                              inputs_embeds_masks)
             residual = None
         else:
             assert intermediate_tensors is not None
@@ -484,9 +486,12 @@ def forward(
         kv_caches: List[torch.Tensor],
         attn_metadata: AttentionMetadata,
         intermediate_tensors: Optional[IntermediateTensors] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+        inputs_embeds_masks: Optional[torch.Tensor] = None,
     ) -> Union[torch.Tensor, IntermediateTensors]:
         model_output = self.transformer(input_ids, positions, kv_caches,
-                                        attn_metadata, intermediate_tensors)
+                                        attn_metadata, intermediate_tensors,
+                                        inputs_embeds, inputs_embeds_masks)
         return model_output
 
     def compute_logits(
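
Note: the get_inputs_embeds helper imported from .utils is defined outside this diff (in vllm/model_executor/models/utils.py). Below is a minimal sketch of the behavior the call sites imply, assuming inputs_embeds_masks is a boolean mask over token positions and inputs_embeds carries one row per masked position; both the shape convention and the logic are assumptions, not the actual vLLM implementation.

from typing import Callable, Optional

import torch


def get_inputs_embeds(
    input_ids: torch.Tensor,
    get_input_embeddings: Callable[[torch.Tensor], torch.Tensor],
    inputs_embeds: Optional[torch.Tensor] = None,
    inputs_embeds_masks: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    # Hypothetical sketch; the real helper in vllm/model_executor/models/utils.py
    # may differ in shape conventions and edge-case handling.
    if inputs_embeds is None:
        # No precomputed embeddings: embed the token ids as before.
        return get_input_embeddings(input_ids)
    if inputs_embeds_masks is None:
        # Precomputed embeddings cover every position.
        return inputs_embeds
    # Mixed batch: embed the token ids, then overwrite the masked positions
    # with the precomputed embeddings (assumed to be one row per True entry).
    hidden_states = get_input_embeddings(input_ids)
    hidden_states[inputs_embeds_masks] = inputs_embeds.to(hidden_states.dtype)
    return hidden_states

Under that reading, callers that pass neither optional argument keep the old token-embedding path, so existing users of these forward methods are unaffected.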

vllm/model_executor/models/granite.py

Lines changed: 11 additions & 6 deletions
@@ -52,7 +52,8 @@
 from vllm.utils import is_hip
 
 from .interfaces import SupportsLoRA
-from .utils import PPMissingLayer, is_pp_missing_parameter, make_layers
+from .utils import (PPMissingLayer, get_inputs_embeds, is_pp_missing_parameter,
+                    make_layers)
 
 
 class GraniteMLP(nn.Module):
@@ -304,12 +305,13 @@ def forward(
         attn_metadata: AttentionMetadata,
         intermediate_tensors: Optional[IntermediateTensors],
         inputs_embeds: Optional[torch.Tensor] = None,
+        inputs_embeds_masks: Optional[torch.Tensor] = None,
     ) -> Union[torch.Tensor, IntermediateTensors]:
         if get_pp_group().is_first_rank:
-            if inputs_embeds is not None:
-                hidden_states = inputs_embeds
-            else:
-                hidden_states = self.get_input_embeddings(input_ids)
+            hidden_states = get_inputs_embeds(input_ids,
+                                              self.get_input_embeddings,
+                                              inputs_embeds,
+                                              inputs_embeds_masks)
             residual = None
         else:
             assert intermediate_tensors is not None
@@ -418,9 +420,12 @@ def forward(
         kv_caches: List[torch.Tensor],
         attn_metadata: AttentionMetadata,
         intermediate_tensors: Optional[IntermediateTensors] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+        inputs_embeds_masks: Optional[torch.Tensor] = None,
     ) -> Union[torch.Tensor, IntermediateTensors]:
         model_output = self.model(input_ids, positions, kv_caches,
-                                  attn_metadata, intermediate_tensors)
+                                  attn_metadata, intermediate_tensors,
+                                  inputs_embeds, inputs_embeds_masks)
         return model_output
 
     def compute_logits(

vllm/model_executor/models/phimoe.py

Lines changed: 10 additions & 2 deletions
@@ -47,6 +47,7 @@
 from vllm.sequence import IntermediateTensors
 
 from .interfaces import SupportsLoRA
+from .utils import get_inputs_embeds
 
 
 class PhiMoEConfig(PretrainedConfig):
@@ -462,8 +463,12 @@ def forward(
         positions: torch.Tensor,
         kv_caches: List[torch.Tensor],
         attn_metadata: AttentionMetadata,
+        inputs_embeds: Optional[torch.Tensor] = None,
+        inputs_embeds_masks: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
-        hidden_states = self.embed_tokens(input_ids)
+        hidden_states = get_inputs_embeds(input_ids, self.embed_tokens,
+                                          inputs_embeds, inputs_embeds_masks)
+
         residual = None
         for i in range(len(self.layers)):
             layer = self.layers[i]
@@ -540,9 +545,12 @@ def forward(
         kv_caches: List[torch.Tensor],
         attn_metadata: AttentionMetadata,
         intermediate_tensors: Optional[IntermediateTensors] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+        inputs_embeds_masks: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         hidden_states = self.model(input_ids, positions, kv_caches,
-                                   attn_metadata)
+                                   attn_metadata, inputs_embeds,
+                                   inputs_embeds_masks)
         return hidden_states
 
     def compute_logits(self, hidden_states: torch.Tensor,
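
As a quick, self-contained illustration of the merge semantics assumed above (plain torch, no vLLM imports; the shapes and the convention that True means "use the precomputed embedding" are assumptions):

import torch
import torch.nn as nn

torch.manual_seed(0)
vocab_size, hidden_size, num_tokens = 32, 8, 5
embed_tokens = nn.Embedding(vocab_size, hidden_size)

with torch.no_grad():
    input_ids = torch.randint(0, vocab_size, (num_tokens,))
    # Positions 1 and 3 already carry precomputed embeddings,
    # e.g. projected image features in a multimodal prompt.
    mask = torch.tensor([False, True, False, True, False])
    precomputed = torch.randn(int(mask.sum()), hidden_size)

    hidden_states = embed_tokens(input_ids)
    hidden_states[mask] = precomputed  # overwrite masked positions only
    token_embeds = embed_tokens(input_ids)

    assert torch.equal(hidden_states[mask], precomputed)
    assert torch.equal(hidden_states[~mask], token_embeds[~mask])
    print(hidden_states.shape)  # torch.Size([5, 8])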
