1212from vllm .config import VllmConfig
1313from vllm .model_executor .layers .linear import (ColumnParallelLinear ,
1414 RowParallelLinear )
15- from vllm .model_executor .layers .pooler import Pooler , PoolingType
15+ from vllm .model_executor .layers .pooler import Pooler , PoolingType , SimplePooler
1616from vllm .model_executor .pooling_metadata import PoolingMetadata
1717from vllm .sequence import IntermediateTensors , PoolerOutput
1818
@@ -32,7 +32,7 @@ def forward(self, input):
3232 return self .activation (input )
3333
3434
35- class Qwen2ForRewardModel (nn .Module , SupportsLoRA , SupportsPP ):
35+ class Qwen2RewardBaseModel (nn .Module , SupportsLoRA , SupportsPP ):
3636 packed_modules_mapping = {
3737 "qkv_proj" : [
3838 "q_proj" ,
@@ -60,7 +60,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
6060 config = vllm_config .model_config .hf_config
6161 quant_config = vllm_config .quant_config
6262 lora_config = vllm_config .lora_config
63- pooler_config = vllm_config .model_config .pooler_config
6463
6564 self .config = config
6665 self .lora_config = lora_config
@@ -74,14 +73,11 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
7473 config .hidden_size ,
7574 quant_config = quant_config ),
7675 ReLU (),
77- RowParallelLinear (config .hidden_size , 1 ,
76+ RowParallelLinear (config .hidden_size ,
77+ config .num_labels ,
7878 quant_config = quant_config ),
7979 )
80- self ._pooler = Pooler .from_config_with_defaults (
81- pooler_config ,
82- pooling_type = PoolingType .ALL ,
83- normalize = False ,
84- softmax = False )
80+ self ._pooler : SimplePooler
8581 self .make_empty_intermediate_tensors = (
8682 self .model .make_empty_intermediate_tensors )
8783
@@ -115,3 +111,31 @@ def load_weights(self, weights: Iterable[Tuple[str,
115111 loader = AutoWeightsLoader (self ,
116112 ignore_unexpected_prefixes = ["lm_head." ])
117113 return loader .load_weights (weights )
114+
115+
class Qwen2ForRewardModel(Qwen2RewardBaseModel):
    """Reward-model variant with a single output label and ``ALL`` pooling.

    Forces ``hf_config.num_labels`` to 1 before delegating to the base
    constructor, then installs a pooler built with
    ``PoolingType.ALL`` and with normalization and softmax disabled.
    """

    def __init__(self, *, vllm_config, prefix=""):
        model_config = vllm_config.model_config

        # Must happen before the base constructor runs: the base class uses
        # ``config.num_labels`` as the output width of its final projection.
        # Note that this writes to the caller's hf_config object.
        model_config.hf_config.num_labels = 1
        super().__init__(vllm_config=vllm_config, prefix=prefix)

        # The base class only declares ``_pooler``; each subclass supplies
        # the concrete pooler. Values here are passed as defaults alongside
        # the user's pooler_config.
        self._pooler = Pooler.from_config_with_defaults(
            model_config.pooler_config,
            pooling_type=PoolingType.ALL,
            normalize=False,
            softmax=False,
        )
127+
128+
class Qwen2ForProcessRewardModel(Qwen2RewardBaseModel):
    """Process-reward variant with two output labels and ``STEP`` pooling.

    Forces ``hf_config.num_labels`` to 2 before delegating to the base
    constructor, then installs a pooler built with ``PoolingType.STEP``,
    softmax enabled, normalization disabled, and a fixed step tag id.
    """

    def __init__(self, *, vllm_config, prefix=""):
        model_config = vllm_config.model_config

        # Must happen before the base constructor runs: the base class uses
        # ``config.num_labels`` as the output width of its final projection.
        # Note that this writes to the caller's hf_config object.
        model_config.hf_config.num_labels = 2
        super().__init__(vllm_config=vllm_config, prefix=prefix)

        # The base class only declares ``_pooler``; each subclass supplies
        # the concrete pooler. Values here are passed as defaults alongside
        # the user's pooler_config.
        # NOTE(review): 151651 is a hard-coded token id -- presumably the
        # step-separator token of the intended tokenizer; confirm it matches
        # the checkpoints this class is registered for.
        self._pooler = Pooler.from_config_with_defaults(
            model_config.pooler_config,
            pooling_type=PoolingType.STEP,
            normalize=False,
            softmax=True,
            step_tag_id=151651,
        )
0 commit comments