@@ -54,7 +54,8 @@
 from vllm.utils import is_hip

 from .interfaces import SupportsLoRA
-from .utils import PPMissingLayer, is_pp_missing_parameter, make_layers
+from .utils import (PPMissingLayer, get_inputs_embeds, is_pp_missing_parameter,
+                    make_layers)


 class ExaoneGatedMLP(nn.Module):
@@ -365,12 +366,13 @@ def forward(
         attn_metadata: AttentionMetadata,
         intermediate_tensors: Optional[IntermediateTensors],
         inputs_embeds: Optional[torch.Tensor] = None,
+        inputs_embeds_masks: Optional[torch.Tensor] = None,
     ) -> Union[torch.Tensor, IntermediateTensors]:
         if get_pp_group().is_first_rank:
-            if inputs_embeds is not None:
-                hidden_states = inputs_embeds
-            else:
-                hidden_states = self.get_input_embeddings(input_ids)
+            hidden_states = get_inputs_embeds(input_ids,
+                                              self.get_input_embeddings,
+                                              inputs_embeds,
+                                              inputs_embeds_masks)
             residual = None
         else:
             assert intermediate_tensors is not None
@@ -484,9 +486,12 @@ def forward(
         kv_caches: List[torch.Tensor],
         attn_metadata: AttentionMetadata,
         intermediate_tensors: Optional[IntermediateTensors] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+        inputs_embeds_masks: Optional[torch.Tensor] = None,
     ) -> Union[torch.Tensor, IntermediateTensors]:
         model_output = self.transformer(input_ids, positions, kv_caches,
-                                        attn_metadata, intermediate_tensors)
+                                        attn_metadata, intermediate_tensors,
+                                        inputs_embeds, inputs_embeds_masks)
         return model_output

     def compute_logits(
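
The helper get_inputs_embeds imported from .utils is called above but its body is not part of this diff. As a rough sketch of the assumed semantics, inferred only from the call site: fall back to the embedding table when no embeddings are supplied, and otherwise use inputs_embeds_masks to choose, per token, between caller-supplied embeddings and table lookups. The signature, shapes, and mask meaning below are assumptions for illustration, not the actual vLLM implementation.

# Illustrative sketch only -- not the real vllm .utils helper.
# Assumes inputs_embeds has the same shape as the token embeddings
# ([num_tokens, hidden_size]) and inputs_embeds_masks is a boolean tensor
# over tokens marking positions whose embeddings are supplied directly.
from typing import Callable, Optional

import torch


def get_inputs_embeds(
    input_ids: torch.Tensor,
    get_input_embeddings: Callable[[torch.Tensor], torch.Tensor],
    inputs_embeds: Optional[torch.Tensor] = None,
    inputs_embeds_masks: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    if inputs_embeds is None:
        # No precomputed embeddings: look everything up from the embedding table.
        return get_input_embeddings(input_ids)
    if inputs_embeds_masks is None:
        # Embeddings were provided for every position.
        return inputs_embeds
    # Mixed case: start from table lookups, then overwrite the masked
    # positions with the caller-supplied embeddings.
    hidden_states = get_input_embeddings(input_ids)
    hidden_states[inputs_embeds_masks] = inputs_embeds[inputs_embeds_masks]
    return hidden_states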