Commit 68eb5c8

[Misc] Move functions into PoolingMetadata (vllm-project#30027)
Signed-off-by: DarkLight1337 <[email protected]>
Parent: 5430e11

3 files changed

Lines changed: 30 additions & 47 deletions

vllm/model_executor/layers/pooler.py

Lines changed: 7 additions & 43 deletions
@@ -64,42 +64,6 @@ def apply(self, params: PoolingParams) -> None:
         params.requires_token_ids = self.requires_token_ids
 
 
-def get_prompt_lens(
-    hidden_states: torch.Tensor | list[torch.Tensor],
-    pooling_metadata: PoolingMetadata,
-) -> torch.Tensor:
-    return pooling_metadata.prompt_lens
-
-
-def get_prompt_token_ids(pooling_metadata: PoolingMetadata) -> list[torch.Tensor]:
-    assert pooling_metadata.prompt_token_ids is not None, (
-        "Please set `requires_token_ids=True` in `get_pooling_updates`"
-    )
-
-    return [
-        pooling_metadata.prompt_token_ids[i, :num]
-        for i, num in enumerate(pooling_metadata.prompt_lens)
-    ]
-
-
-def get_pooling_params(pooling_metadata: PoolingMetadata) -> list[PoolingParams]:
-    pooling_params = pooling_metadata.pooling_params
-    return pooling_params
-
-
-def get_tasks(pooling_metadata: PoolingMetadata) -> list[PoolingTask]:
-    pooling_params = get_pooling_params(pooling_metadata)
-
-    tasks: list[PoolingTask] = [
-        task
-        for pooling_param in pooling_params
-        if (task := pooling_param.task) is not None
-    ]
-    assert len(pooling_params) == len(tasks)
-
-    return tasks
-
-
 def get_classification_activation_function(config: PretrainedConfig):
     # Implement alignment with transformers ForSequenceClassificationLoss
     # https://github.com/huggingface/transformers/blob/57bb6db6ee4cfaccc45b8d474dfad5a17811ca60/src/transformers/loss/loss_utils.py#L92
@@ -466,7 +430,7 @@ def forward(
         pooled_data = self.projector(pooled_data)
         # pooled_data shape: [batchsize, embedding_dimension]
 
-        pooling_params = get_pooling_params(pooling_metadata)
+        pooling_params = pooling_metadata.pooling_params
 
         # for matryoshka representation
         dimensions_list = [pooling_param.dimensions for pooling_param in pooling_params]
@@ -606,7 +570,7 @@ def forward(
         if self.logit_bias is not None:
             pooled_data -= self.logit_bias
 
-        pooling_params = get_pooling_params(pooling_metadata)
+        pooling_params = pooling_metadata.pooling_params
         flags = [p.use_activation for p in pooling_params]
 
         if len(set(flags)) == 1:
@@ -704,7 +668,7 @@ def forward(
         pooling_metadata: PoolingMetadata,
     ) -> PoolerOutput:
         pooled_data = self.pooling(hidden_states, pooling_metadata)
-        pooling_params = get_pooling_params(pooling_metadata)
+        pooling_params = pooling_metadata.pooling_params
         assert len(pooled_data) == len(pooling_params)
 
         pooled_data = [self.head(d, p) for d, p in zip(pooled_data, pooling_params)]
@@ -724,11 +688,11 @@ def extract_states(
         pooling_metadata: PoolingMetadata,
     ) -> torch.Tensor | list[torch.Tensor]:
         pooled_data_lst = self.pooling(hidden_states, pooling_metadata)
-        prompt_token_ids = get_prompt_token_ids(pooling_metadata)
+        prompt_token_ids = pooling_metadata.get_prompt_token_ids()
 
         pooled_data = list[torch.Tensor]()
 
-        pooling_params = get_pooling_params(pooling_metadata)
+        pooling_params = pooling_metadata.pooling_params
 
         for data, token_id, pooling_param in zip(
             pooled_data_lst, prompt_token_ids, pooling_params
@@ -757,7 +721,7 @@ def forward(
         pooling_metadata: PoolingMetadata,
     ) -> PoolerOutput:
         pooled_data = self.extract_states(hidden_states, pooling_metadata)
-        pooling_params = get_pooling_params(pooling_metadata)
+        pooling_params = pooling_metadata.pooling_params
         assert len(pooled_data) == len(pooling_params)
 
         pooled_data = [self.head(d, p) for d, p in zip(pooled_data, pooling_params)]
@@ -794,7 +758,7 @@ def forward(
 
         outputs = list[torch.Tensor]()
         offset = 0
-        for task, group in groupby(get_tasks(pooling_metadata)):
+        for task, group in groupby(pooling_metadata.tasks):
             if not (pooler := poolers_by_task.get(task)):
                 raise ValueError(
                     f"Unsupported task: {task} "

vllm/model_executor/models/gritlm.py

Lines changed: 2 additions & 4 deletions
@@ -14,8 +14,6 @@
     PoolerHead,
     PoolerNormalize,
     PoolingParamsUpdate,
-    get_prompt_lens,
-    get_prompt_token_ids,
 )
 from vllm.model_executor.models.llama import LlamaForCausalLM
 from vllm.tasks import PoolingTask
@@ -153,11 +151,11 @@ def forward(
         hidden_states: torch.Tensor | list[torch.Tensor],
         pooling_metadata: PoolingMetadata,
     ) -> list[torch.Tensor] | torch.Tensor:
-        prompt_lens = get_prompt_lens(hidden_states, pooling_metadata)
+        prompt_lens = pooling_metadata.prompt_lens
         instr_lens = torch.tensor(
             [
                 self._get_instruction_len(token_ids.cpu().numpy())
-                for token_ids in get_prompt_token_ids(pooling_metadata)
+                for token_ids in pooling_metadata.get_prompt_token_ids()
             ],
             device="cpu",
         )

vllm/v1/pool/metadata.py

Lines changed: 21 additions & 0 deletions
@@ -5,6 +5,7 @@
 import torch
 
 from vllm.pooling_params import PoolingParams
+from vllm.tasks import PoolingTask
 from vllm.utils.platform_utils import is_pin_memory_available
 
 pin_memory = is_pin_memory_available()
@@ -40,6 +41,18 @@ class PoolingMetadata:
     pooling_params: list[PoolingParams]
     pooling_cursor: PoolingCursor | None = None
 
+    def __post_init__(self) -> None:
+        pooling_params = self.pooling_params
+
+        tasks: list[PoolingTask] = [
+            task
+            for pooling_param in pooling_params
+            if (task := pooling_param.task) is not None
+        ]
+        assert len(pooling_params) == len(tasks)
+
+        self.tasks = tasks
+
     def __getitem__(self, indices: slice):
         return PoolingMetadata(
             prompt_lens=self.prompt_lens[indices],
@@ -52,6 +65,14 @@ def __getitem__(self, indices: slice):
             else self.pooling_cursor[indices],
         )
 
+    def get_prompt_token_ids(self) -> list[torch.Tensor]:
+        prompt_token_ids = self.prompt_token_ids
+        assert prompt_token_ids is not None, (
+            "Please set `requires_token_ids=True` in `get_pooling_updates`"
+        )
+
+        return [prompt_token_ids[i, :num] for i, num in enumerate(self.prompt_lens)]
+
     def build_pooling_cursor(
         self, num_scheduled_tokens: list[int], device: torch.device
     ):
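
The trimming performed by the new get_prompt_token_ids() is plain tensor slicing over the right-padded token-id matrix. A standalone sketch of the same comprehension, with made-up lengths and token ids (not taken from the commit):

import torch

# Right-padded token-id matrix for two prompts of lengths 3 and 5 (made-up values).
prompt_token_ids = torch.tensor(
    [
        [11, 12, 13, 0, 0],
        [21, 22, 23, 24, 25],
    ]
)
prompt_lens = torch.tensor([3, 5])

# Same comprehension as PoolingMetadata.get_prompt_token_ids(): one trimmed
# 1-D tensor per request.
trimmed = [prompt_token_ids[i, :num] for i, num in enumerate(prompt_lens)]
print([t.tolist() for t in trimmed])  # [[11, 12, 13], [21, 22, 23, 24, 25]]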
