diff --git a/tests/plugins/vllm_add_dummy_model/vllm_add_dummy_model/my_gemma_embedding.py b/tests/plugins/vllm_add_dummy_model/vllm_add_dummy_model/my_gemma_embedding.py
index 3af62b2885e5..a376d2cb340c 100644
--- a/tests/plugins/vllm_add_dummy_model/vllm_add_dummy_model/my_gemma_embedding.py
+++ b/tests/plugins/vllm_add_dummy_model/vllm_add_dummy_model/my_gemma_embedding.py
@@ -1,11 +1,10 @@
 # SPDX-License-Identifier: Apache-2.0
 
-from typing import Iterable, List, Optional, Tuple, Union
+from typing import Iterable, Optional, Tuple, Union
 
 import torch
 import torch.nn as nn
 
-from vllm.attention import AttentionMetadata
 from vllm.config import VllmConfig
 from vllm.model_executor.layers.pooler import Pooler, PoolingType
 from vllm.model_executor.models.gemma2 import Gemma2Model
@@ -37,16 +36,12 @@ def forward(
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
-        kv_caches: List[torch.Tensor],
-        attn_metadata: AttentionMetadata,
         intermediate_tensors: Optional[IntermediateTensors] = None,
         inputs_embeds: Optional[torch.Tensor] = None,
     ) -> Union[torch.Tensor, IntermediateTensors]:
         hidden_states = self.model(
             input_ids,
             positions,
-            kv_caches,
-            attn_metadata,
             intermediate_tensors=intermediate_tensors,
             inputs_embeds=inputs_embeds,
         )
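
For reference, this is a sketch of how the dummy-model plugin's `forward` reads after the change, reconstructed purely from the hunk above: `kv_caches` and `attn_metadata` are no longer threaded through the signature, and the remaining arguments pass straight to the wrapped `Gemma2Model`. The enclosing class and any pooling code after the `self.model(...)` call are not visible in the diff and are elided here.

```python
    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        # kv_caches / attn_metadata are gone from the signature; the model
        # call forwards only the surviving arguments.
        hidden_states = self.model(
            input_ids,
            positions,
            intermediate_tensors=intermediate_tensors,
            inputs_embeds=inputs_embeds,
        )
        ...  # remainder of the method is outside the diff's context window
```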