1- import os
1+ import logging
22from dataclasses import dataclass
3- from typing import Dict , Optional
3+ from typing import Optional
44
55import torch
66from torch import nn , Tensor
1111
1212from tevatron .reranker .arguments import ModelArguments
1313
14- import logging
15-
1614logger = logging .getLogger (__name__ )
1715
1816
@@ -27,6 +25,7 @@ class RerankerModel(nn.Module):
2725
2826 def __init__ (self , hf_model : PreTrainedModel , train_batch_size : int = None ):
2927 super ().__init__ ()
28+ logger .info (f"Initializing RerankerModel with train_batch_size: { train_batch_size } " )
3029 self .config = hf_model .config
3130 self .hf_model = hf_model
3231 self .train_batch_size = train_batch_size
@@ -36,54 +35,52 @@ def __init__(self, hf_model: PreTrainedModel, train_batch_size: int = None):
3635 'target_label' ,
3736 torch .zeros (self .train_batch_size , dtype = torch .long , device = self .hf_model .device )
3837 )
39- for name , param in self .hf_model .named_parameters ():
40- # for some reason, ds zero 3 left some weights empty
41- if 'modules_to_save' in name and param .numel () == 0 :
42- logger .warning (f'parameter { name } , shape { param .shape } is empty' )
43- param .data = nn .Linear (self .hf_model .config .hidden_size , 1 ).weight .data
44- logger .warning ('{} data: {}' .format (name , param .data .cpu ().numpy ()))
38+ logger .info (f"RerankerModel initialized with config: { self .config } " )
4539
def forward(self, input_ids: Tensor = None, attention_mask: Tensor = None, labels: Tensor = None, **kwargs):
    """Score query-document pairs and optionally compute a grouped training loss.

    Args:
        input_ids: token ids of the concatenated query/document pairs.
        attention_mask: attention mask matching ``input_ids``.
        labels: optional per-group target indices; when provided, a listwise
            cross-entropy loss is computed over groups of ``train_batch_size``
            logits (one group per query).
        **kwargs: forwarded verbatim to the underlying HF model.

    Returns:
        RerankerOutput with ``scores`` (raw logits) and ``loss`` (``None`` at
        inference time).

    Raises:
        ValueError: if ``labels`` are given but the model was built without a
            ``train_batch_size`` (logits cannot be grouped).
    """
    # Lazy %-style args: the shape expression is cheap, but formatting is
    # skipped entirely unless DEBUG is enabled.
    logger.debug(
        "Forward pass with input shape: %s",
        input_ids.shape if input_ids is not None else "None",
    )
    outputs = self.hf_model(input_ids=input_ids, attention_mask=attention_mask, **kwargs)

    loss = None
    if labels is not None:
        # Bug fix: models built for inference have train_batch_size=None;
        # view(None, -1) would raise an opaque TypeError. Fail loudly instead.
        if not self.train_batch_size:
            raise ValueError("train_batch_size must be set to compute a training loss")
        grouped_logits = outputs.logits.view(self.train_batch_size, -1)
        loss = self.cross_entropy(grouped_logits, labels)
        # Pass the tensor itself, not loss.item(): avoids a GPU sync when the
        # DEBUG level is disabled (str() only runs if the record is emitted).
        logger.debug("Computed loss: %s", loss)
    else:
        logger.debug("No labels provided, skipping loss computation")

    return RerankerOutput(
        loss=loss,
        scores=outputs.logits,
    )
6055
61- def gradient_checkpointing_enable (self , ** kwargs ):
62- return False
63- # self.hf_model.base_model.model.gradient_checkpointing_enable(**kwargs)
64-
6556 @classmethod
6657 def build (
6758 cls ,
6859 model_args : ModelArguments ,
6960 train_args : TrainingArguments ,
7061 ** hf_kwargs ,
7162 ):
63+ logger .info (f"Building RerankerModel with args: { model_args } " )
7264 base_model = cls .TRANSFORMER_CLS .from_pretrained (
7365 model_args .model_name_or_path ,
7466 ** hf_kwargs ,
7567 )
7668 if base_model .config .pad_token_id is None :
7769 base_model .config .pad_token_id = 0
70+ logger .info ("Set pad_token_id to 0" )
71+
7872 if model_args .lora or model_args .lora_name_or_path :
73+ logger .info ("Applying LoRA" )
7974 if train_args .gradient_checkpointing :
8075 base_model .enable_input_require_grads ()
8176 if model_args .lora_name_or_path :
77+ logger .info (f"Loading LoRA from { model_args .lora_name_or_path } " )
8278 lora_config = LoraConfig .from_pretrained (model_args .lora_name_or_path , ** hf_kwargs )
8379 lora_model = PeftModel .from_pretrained (base_model , model_args .lora_name_or_path ,
8480 torch_dtype = torch .bfloat16 ,
8581 attn_implementation = "flash_attention_2" )
8682 else :
83+ logger .info ("Initializing new LoRA" )
8784 lora_config = LoraConfig (
8885 base_model_name_or_path = model_args .model_name_or_path ,
8986 task_type = TaskType .SEQ_CLS ,
@@ -99,6 +96,7 @@ def build(
9996 train_batch_size = train_args .per_device_train_batch_size ,
10097 )
10198 else :
99+ logger .info ("Building model without LoRA" )
102100 model = cls (
103101 hf_model = base_model ,
104102 train_batch_size = train_args .per_device_train_batch_size ,
@@ -110,23 +108,28 @@ def load(cls,
110108 model_name_or_path : str ,
111109 lora_name_or_path : str = None ,
112110 ** hf_kwargs ):
111+ logger .info (f"Loading RerankerModel from { model_name_or_path } " )
113112 base_model = cls .TRANSFORMER_CLS .from_pretrained (model_name_or_path , num_labels = 1 , ** hf_kwargs ,
114113 torch_dtype = torch .bfloat16 ,
115114 attn_implementation = "flash_attention_2" )
116115 if base_model .config .pad_token_id is None :
117116 base_model .config .pad_token_id = 0
117+ logger .info ("Set pad_token_id to 0" )
118118 if lora_name_or_path :
119+ logger .info (f"Loading LoRA from { lora_name_or_path } " )
119120 lora_config = LoraConfig .from_pretrained (lora_name_or_path , ** hf_kwargs )
120121 lora_model = PeftModel .from_pretrained (base_model , lora_name_or_path , config = lora_config )
121122 lora_model = lora_model .merge_and_unload ()
122123 model = cls (
123124 hf_model = lora_model ,
124125 )
125126 else :
127+ logger .info ("Loading model without LoRA" )
126128 model = cls (
127129 hf_model = base_model ,
128130 )
129131 return model
130132
def save(self, output_dir: str):
    """Persist the wrapped HF model's weights and config to ``output_dir``."""
    # %-style args keep string formatting lazy when INFO is disabled.
    logger.info("Saving model to %s", output_dir)
    self.hf_model.save_pretrained(output_dir)
0 commit comments