@@ -154,7 +154,7 @@ def create_device_model(self):
154154 model = AutoModelForImageTextToText .from_pretrained (
155155 pretrained_model_name_or_path = self .config .path ,
156156 trust_remote_code = True ,
157- torch_dtype = dtype ,
157+ dtype = dtype ,
158158 attn_implementation = self .config .attn_impl ,
159159 )
160160 if self .config .disable_dropout :
@@ -179,40 +179,35 @@ def create_device_model(self):
179179
def _create_llm_actor_or_critic(self):
    """Build the HF model for this worker: a causal LM when acting as the
    actor, or a token-classification model with a single scalar label
    (value head) when acting as the critic.

    Returns the instantiated model, either randomly initialized from
    ``self.model_config`` (``init_from_scratch``) or loaded from the
    pretrained checkpoint at ``self.config.path``.
    """
    torch_dtype = getattr(torch, self.config.dtype)

    # Critic needs a 1-label classification head; actor is a plain causal LM.
    if self.config.is_critic:
        model_cls = AutoModelForTokenClassification
        extra_kwargs = {"num_labels": 1}
    else:
        model_cls = AutoModelForCausalLM
        extra_kwargs = {}

    load_kwargs = {
        **extra_kwargs,
        "dtype": torch_dtype,
        "attn_implementation": self.config.attn_impl,
    }

    if self.config.init_from_scratch:
        # initialize model from config
        # NOTE: VLM cannot directly load state dict using this
        # random initialized model, so otherwise we call
        # from_pretrained rather than loading weights into this random model.
        return model_cls.from_config(
            self.model_config,
            **load_kwargs,
        )
    return model_cls.from_pretrained(
        pretrained_model_name_or_path=self.config.path,
        trust_remote_code=True,
        **load_kwargs,
    )
217212
218213 def destroy (self ):
0 commit comments