@@ -64,42 +64,6 @@ def apply(self, params: PoolingParams) -> None:
6464 params .requires_token_ids = self .requires_token_ids
6565
6666
def get_prompt_lens(
    hidden_states: torch.Tensor | list[torch.Tensor],
    pooling_metadata: PoolingMetadata,
) -> torch.Tensor:
    """Return the per-request prompt lengths recorded in *pooling_metadata*.

    ``hidden_states`` is accepted only for signature compatibility with the
    other pooling helpers; it is not consulted here.
    """
    prompt_lens = pooling_metadata.prompt_lens
    return prompt_lens
72-
73-
def get_prompt_token_ids(pooling_metadata: PoolingMetadata) -> list[torch.Tensor]:
    """Slice the padded token-id matrix into one 1-D tensor per request.

    Row ``i`` of ``prompt_token_ids`` holds the (right-padded) token ids of
    request ``i``; only the first ``prompt_lens[i]`` entries are real tokens.

    Raises:
        AssertionError: if token ids were not captured for this batch.
    """
    token_ids = pooling_metadata.prompt_token_ids
    assert token_ids is not None, (
        "Please set `requires_token_ids=True` in `get_pooling_updates`"
    )

    per_request: list[torch.Tensor] = []
    for row, length in enumerate(pooling_metadata.prompt_lens):
        per_request.append(token_ids[row, :length])
    return per_request
83-
84-
def get_pooling_params(pooling_metadata: PoolingMetadata) -> list[PoolingParams]:
    """Return the per-request pooling parameters for the current batch."""
    return pooling_metadata.pooling_params
88-
89-
def get_tasks(pooling_metadata: PoolingMetadata) -> list[PoolingTask]:
    """Collect the pooling task requested by each item in the batch.

    Every request is expected to carry a non-``None`` task, so the result
    stays aligned one-to-one with ``pooling_metadata.pooling_params``.
    """
    pooling_params = get_pooling_params(pooling_metadata)

    tasks: list[PoolingTask] = []
    for pooling_param in pooling_params:
        task = pooling_param.task
        if task is not None:
            tasks.append(task)

    # A request with a missing task would silently shrink the list; keep the
    # output 1:1 with the params so downstream zips stay in step.
    assert len(pooling_params) == len(tasks)

    return tasks
101-
102-
10367def get_classification_activation_function (config : PretrainedConfig ):
10468 # Implement alignment with transformers ForSequenceClassificationLoss
10569 # https://github.com/huggingface/transformers/blob/57bb6db6ee4cfaccc45b8d474dfad5a17811ca60/src/transformers/loss/loss_utils.py#L92
@@ -466,7 +430,7 @@ def forward(
466430 pooled_data = self .projector (pooled_data )
467431 # pooled_data shape: [batchsize, embedding_dimension]
468432
469- pooling_params = get_pooling_params ( pooling_metadata )
433+ pooling_params = pooling_metadata . pooling_params
470434
471435 # for matryoshka representation
472436 dimensions_list = [pooling_param .dimensions for pooling_param in pooling_params ]
@@ -606,7 +570,7 @@ def forward(
606570 if self .logit_bias is not None :
607571 pooled_data -= self .logit_bias
608572
609- pooling_params = get_pooling_params ( pooling_metadata )
573+ pooling_params = pooling_metadata . pooling_params
610574 flags = [p .use_activation for p in pooling_params ]
611575
612576 if len (set (flags )) == 1 :
@@ -704,7 +668,7 @@ def forward(
704668 pooling_metadata : PoolingMetadata ,
705669 ) -> PoolerOutput :
706670 pooled_data = self .pooling (hidden_states , pooling_metadata )
707- pooling_params = get_pooling_params ( pooling_metadata )
671+ pooling_params = pooling_metadata . pooling_params
708672 assert len (pooled_data ) == len (pooling_params )
709673
710674 pooled_data = [self .head (d , p ) for d , p in zip (pooled_data , pooling_params )]
@@ -724,11 +688,11 @@ def extract_states(
724688 pooling_metadata : PoolingMetadata ,
725689 ) -> torch .Tensor | list [torch .Tensor ]:
726690 pooled_data_lst = self .pooling (hidden_states , pooling_metadata )
727- prompt_token_ids = get_prompt_token_ids (pooling_metadata )
691+ prompt_token_ids = pooling_metadata . get_prompt_token_ids ()
728692
729693 pooled_data = list [torch .Tensor ]()
730694
731- pooling_params = get_pooling_params ( pooling_metadata )
695+ pooling_params = pooling_metadata . pooling_params
732696
733697 for data , token_id , pooling_param in zip (
734698 pooled_data_lst , prompt_token_ids , pooling_params
@@ -757,7 +721,7 @@ def forward(
757721 pooling_metadata : PoolingMetadata ,
758722 ) -> PoolerOutput :
759723 pooled_data = self .extract_states (hidden_states , pooling_metadata )
760- pooling_params = get_pooling_params ( pooling_metadata )
724+ pooling_params = pooling_metadata . pooling_params
761725 assert len (pooled_data ) == len (pooling_params )
762726
763727 pooled_data = [self .head (d , p ) for d , p in zip (pooled_data , pooling_params )]
@@ -794,7 +758,7 @@ def forward(
794758
795759 outputs = list [torch .Tensor ]()
796760 offset = 0
797- for task , group in groupby (get_tasks ( pooling_metadata ) ):
761+ for task , group in groupby (pooling_metadata . tasks ):
798762 if not (pooler := poolers_by_task .get (task )):
799763 raise ValueError (
800764 f"Unsupported task: { task } "