@@ -239,7 +239,7 @@ def __init__(
239239 f"a `torch.dtype` (e.g., 'float32'), but got { dtype } ."
240240 )
241241 # Disable caching if gradient checkpointing is enabled (not supported)
242- config = AutoConfig .from_pretrained (model_id )
242+ config = AutoConfig .from_pretrained (model_id , trust_remote_code = self . args . trust_remote_code )
243243 architecture = getattr (transformers , config .architectures [0 ])
244244 model = architecture .from_pretrained (model_id , ** model_init_kwargs )
245245 else :
@@ -263,7 +263,9 @@ def __init__(
263263
264264 # Processing class
265265 if processing_class is None :
266- processing_class = AutoProcessor .from_pretrained (model .config ._name_or_path )
266+ processing_class = AutoProcessor .from_pretrained (
267+ model .config ._name_or_path , trust_remote_code = self .args .trust_remote_code
268+ )
267269
268270 # Handle pad token for processors or tokenizers
269271 if isinstance (processing_class , ProcessorMixin ):
@@ -427,7 +429,7 @@ def __init__(
427429 self .ref_model = None
428430 else :
429431 # For deepspeed, fsdp or non-distributed models, create a reference model from scratch
430- config = AutoConfig .from_pretrained (model_id )
432+ config = AutoConfig .from_pretrained (model_id , trust_remote_code = self . args . trust_remote_code )
431433 architecture = getattr (transformers , config .architectures [0 ])
432434 self .ref_model = architecture .from_pretrained (model_id , ** model_init_kwargs )
433435
@@ -537,6 +539,7 @@ def __init__(
537539 max_num_batched_tokens = 4096 ,
538540 model_impl = self .args .vllm_model_impl ,
539541 enable_sleep_mode = self .args .vllm_enable_sleep_mode ,
542+ trust_remote_code = self .args .trust_remote_code ,
540543 )
541544 if self .args .vllm_enable_sleep_mode :
542545 self .llm .sleep (level = 1 )
0 commit comments