@@ -34,6 +34,7 @@ def __init__(
         words_splitter: Optional[Union[str, WordsSplitter]] = None,
         data_processor: Optional[Union[SpanProcessor, TokenProcessor]] = None,
         encoder_from_pretrained: bool = True,
+        cache_dir: Optional[Union[str, Path]] = None,
     ):
         """
         Initialize the GLiNER model.
@@ -50,19 +51,19 @@ def __init__(
         self.config = config
 
         if tokenizer is None and data_processor is None:
-            tokenizer = AutoTokenizer.from_pretrained(config.model_name)
+            tokenizer = AutoTokenizer.from_pretrained(config.model_name, cache_dir=cache_dir)
 
         if words_splitter is None and data_processor is None:
             words_splitter = WordsSplitter(config.words_splitter_type)
 
         if config.span_mode == "token_level":
             if model is None:
-                self.model = TokenModel(config, encoder_from_pretrained)
+                self.model = TokenModel(config, encoder_from_pretrained, cache_dir=cache_dir)
             else:
                 self.model = model
             if data_processor is None:
                 if config.labels_encoder is not None:
-                    labels_tokenizer = AutoTokenizer.from_pretrained(config.labels_encoder)
+                    labels_tokenizer = AutoTokenizer.from_pretrained(config.labels_encoder, cache_dir=cache_dir)
                     self.data_processor = TokenBiEncoderProcessor(config, tokenizer, words_splitter, labels_tokenizer)
                 else:
                     self.data_processor = TokenProcessor(config, tokenizer, words_splitter)
@@ -72,12 +73,12 @@ def __init__(
             self.decoder = TokenDecoder(config)
         else:
             if model is None:
-                self.model = SpanModel(config, encoder_from_pretrained)
+                self.model = SpanModel(config, encoder_from_pretrained, cache_dir=cache_dir)
             else:
                 self.model = model
             if data_processor is None:
                 if config.labels_encoder is not None:
-                    labels_tokenizer = AutoTokenizer.from_pretrained(config.labels_encoder)
+                    labels_tokenizer = AutoTokenizer.from_pretrained(config.labels_encoder, cache_dir=cache_dir)
                     self.data_processor = SpanBiEncoderProcessor(config, tokenizer, words_splitter, labels_tokenizer)
                 else:
                     self.data_processor = SpanProcessor(config, tokenizer, words_splitter)
@@ -778,10 +779,10 @@ def _from_pretrained(
         config_file = Path(model_dir) / "gliner_config.json"
 
         if load_tokenizer:
-            tokenizer = AutoTokenizer.from_pretrained(model_dir)
+            tokenizer = AutoTokenizer.from_pretrained(model_dir, cache_dir=cache_dir)
         else:
             if os.path.exists(os.path.join(model_dir, "tokenizer_config.json")):
-                tokenizer = AutoTokenizer.from_pretrained(model_dir)
+                tokenizer = AutoTokenizer.from_pretrained(model_dir, cache_dir=cache_dir)
             else:
                 tokenizer = None
         with open(config_file, "r") as f:
@@ -801,7 +802,7 @@ def _from_pretrained(
         add_tokens = ["[FLERT]", config.ent_token, config.sep_token]
 
         if not load_onnx_model:
-            gliner = cls(config, tokenizer=tokenizer, encoder_from_pretrained=False)
+            gliner = cls(config, tokenizer=tokenizer, encoder_from_pretrained=False, cache_dir=cache_dir)
             # to be able to load GLiNER models from previous version
             if (
                 config.class_token_index == -1 or config.vocab_size == -1
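For reference, a minimal usage sketch of the new parameter. The model id and cache path are illustrative, and it assumes cache_dir is forwarded by the hub mixin's from_pretrained into _from_pretrained, as the hunks above rely on:

# Usage sketch (illustrative model id and cache path; assumes this change is applied).
from pathlib import Path
from gliner import GLiNER

cache_dir = Path("./hf_cache")  # downloaded encoder/tokenizer files land here instead of the default HF cache

# from_pretrained forwards cache_dir to _from_pretrained, which now also passes it
# to AutoTokenizer.from_pretrained and to the SpanModel/TokenModel constructors.
model = GLiNER.from_pretrained("urchade/gliner_multi-v2.1", cache_dir=cache_dir)

entities = model.predict_entities(
    "Marie Curie won the Nobel Prize in Physics in 1903.",
    labels=["person", "award", "date"],
)
for ent in entities:
    print(ent["text"], "->", ent["label"])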