Commit c15e090

Merge pull request #250 from bit2244/main

add cache_dir parameter to AutoConfig, AutoTokenizer

2 parents: c716ce3 + 5e78218

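From the user's side, the change lets the Hugging Face download cache be redirected when loading a model. A minimal usage sketch, assuming an example checkpoint name and that from_pretrained forwards cache_dir to _from_pretrained as the diff below implies:

    from gliner import GLiNER

    # Hypothetical usage: download model weights, config, and tokenizer files
    # under a custom directory instead of the default ~/.cache/huggingface.
    model = GLiNER.from_pretrained(
        "urchade/gliner_small-v2.1",  # example checkpoint name, not part of this commit
        cache_dir="/data/hf_cache",
    )
    entities = model.predict_entities("Bill Gates founded Microsoft.", ["person", "company"])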
File tree

3 files changed: +36 -26 lines

  gliner/model.py
  gliner/modeling/base.py
  gliner/modeling/encoder.py

gliner/model.py

Lines changed: 9 additions & 8 deletions

@@ -34,6 +34,7 @@ def __init__(
         words_splitter: Optional[Union[str, WordsSplitter]] = None,
         data_processor: Optional[Union[SpanProcessor, TokenProcessor]] = None,
         encoder_from_pretrained: bool = True,
+        cache_dir: Optional[Union[str, Path]] = None,
     ):
         """
         Initialize the GLiNER model.
@@ -50,19 +51,19 @@ def __init__(
         self.config = config

         if tokenizer is None and data_processor is None:
-            tokenizer = AutoTokenizer.from_pretrained(config.model_name)
+            tokenizer = AutoTokenizer.from_pretrained(config.model_name, cache_dir=cache_dir)

         if words_splitter is None and data_processor is None:
             words_splitter = WordsSplitter(config.words_splitter_type)

         if config.span_mode == "token_level":
             if model is None:
-                self.model = TokenModel(config, encoder_from_pretrained)
+                self.model = TokenModel(config, encoder_from_pretrained, cache_dir=cache_dir)
             else:
                 self.model = model
             if data_processor is None:
                 if config.labels_encoder is not None:
-                    labels_tokenizer = AutoTokenizer.from_pretrained(config.labels_encoder)
+                    labels_tokenizer = AutoTokenizer.from_pretrained(config.labels_encoder, cache_dir=cache_dir)
                     self.data_processor = TokenBiEncoderProcessor(config, tokenizer, words_splitter, labels_tokenizer)
                 else:
                     self.data_processor = TokenProcessor(config, tokenizer, words_splitter)
@@ -72,12 +73,12 @@ def __init__(
             self.decoder = TokenDecoder(config)
         else:
             if model is None:
-                self.model = SpanModel(config, encoder_from_pretrained)
+                self.model = SpanModel(config, encoder_from_pretrained, cache_dir=cache_dir)
             else:
                 self.model = model
             if data_processor is None:
                 if config.labels_encoder is not None:
-                    labels_tokenizer = AutoTokenizer.from_pretrained(config.labels_encoder)
+                    labels_tokenizer = AutoTokenizer.from_pretrained(config.labels_encoder, cache_dir=cache_dir)
                     self.data_processor = SpanBiEncoderProcessor(config, tokenizer, words_splitter, labels_tokenizer)
                 else:
                     self.data_processor = SpanProcessor(config, tokenizer, words_splitter)
@@ -778,10 +779,10 @@ def _from_pretrained(
         config_file = Path(model_dir) / "gliner_config.json"

         if load_tokenizer:
-            tokenizer = AutoTokenizer.from_pretrained(model_dir)
+            tokenizer = AutoTokenizer.from_pretrained(model_dir, cache_dir=cache_dir)
         else:
             if os.path.exists(os.path.join(model_dir, "tokenizer_config.json")):
-                tokenizer = AutoTokenizer.from_pretrained(model_dir)
+                tokenizer = AutoTokenizer.from_pretrained(model_dir, cache_dir=cache_dir)
             else:
                 tokenizer = None
         with open(config_file, "r") as f:
@@ -801,7 +802,7 @@ def _from_pretrained(
         add_tokens = ["[FLERT]", config.ent_token, config.sep_token]

         if not load_onnx_model:
-            gliner = cls(config, tokenizer=tokenizer, encoder_from_pretrained=False)
+            gliner = cls(config, tokenizer=tokenizer, encoder_from_pretrained=False, cache_dir=cache_dir)
            # to be able to load GLiNER models from previous version
            if (
                config.class_token_index == -1 or config.vocab_size == -1

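The diff above threads cache_dir from GLiNER.__init__ down to AutoTokenizer.from_pretrained and to the inner SpanModel/TokenModel. A minimal sketch of constructing a model from a config with the new argument; the GLiNERConfig import path and default fields are assumptions, not part of this commit:

    from gliner import GLiNER, GLiNERConfig  # GLiNERConfig export path assumed

    # Hypothetical: build a fresh (untrained) model from a config; cache_dir is
    # passed on to the tokenizer loader and to the span- or token-level model.
    config = GLiNERConfig(model_name="microsoft/deberta-v3-small")  # example backbone
    model = GLiNER(config, cache_dir="/data/hf_cache")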
gliner/modeling/base.py

Lines changed: 13 additions & 12 deletions

@@ -1,4 +1,5 @@
-from typing import Optional, Tuple
+from pathlib import Path
+from typing import Optional, Tuple, Union
 from abc import ABC, abstractmethod
 from dataclasses import dataclass
 import warnings
@@ -80,14 +81,14 @@ def extract_prompt_features_and_word_embeddings(config, token_embeds, input_ids,


 class BaseModel(ABC, nn.Module):
-    def __init__(self, config, from_pretrained=False):
+    def __init__(self, config, from_pretrained = False, cache_dir: Optional[Union[str, Path]] = None):
         super(BaseModel, self).__init__()
         self.config = config

         if not config.labels_encoder:
-            self.token_rep_layer = Encoder(config, from_pretrained)
+            self.token_rep_layer = Encoder(config, from_pretrained, cache_dir = cache_dir)
         else:
-            self.token_rep_layer = BiEncoder(config, from_pretrained)
+            self.token_rep_layer = BiEncoder(config, from_pretrained, cache_dir=cache_dir)
         if self.config.has_rnn:
             self.rnn = LstmSeq2SeqEncoder(config)

@@ -240,12 +241,12 @@ def loss(self, x):


 class SpanModel(BaseModel):
-    def __init__(self, config, encoder_from_pretrained):
-        super(SpanModel, self).__init__(config, encoder_from_pretrained)
-        self.span_rep_layer = SpanRepLayer(span_mode=config.span_mode,
-                                           hidden_size=config.hidden_size,
-                                           max_width=config.max_width,
-                                           dropout=config.dropout)
+    def __init__(self, config, encoder_from_pretrained, cache_dir: Optional[Union[str, Path]] = None):
+        super(SpanModel, self).__init__(config, encoder_from_pretrained, cache_dir = cache_dir)
+        self.span_rep_layer = SpanRepLayer(span_mode = config.span_mode,
+                                           hidden_size = config.hidden_size,
+                                           max_width = config.max_width,
+                                           dropout = config.dropout)

         self.prompt_rep_layer = create_projection_layer(config.hidden_size, config.dropout)

@@ -331,8 +332,8 @@ def loss(self, scores, labels, prompts_embedding_mask, mask_label,


 class TokenModel(BaseModel):
-    def __init__(self, config, encoder_from_pretrained):
-        super(TokenModel, self).__init__(config, encoder_from_pretrained)
+    def __init__(self, config, encoder_from_pretrained, cache_dir:Optional[Union[str, Path]] = None):
+        super(TokenModel, self).__init__(config, encoder_from_pretrained, cache_dir=cache_dir)
         self.scorer = Scorer(config.hidden_size, config.dropout)

     def forward(self,

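The same keyword is accepted when the inner models are built directly; BaseModel forwards it to Encoder, or to BiEncoder when config.labels_encoder is set. A short sketch, assuming a GLiNERConfig instance named config is already prepared:

    from gliner.modeling.base import SpanModel, TokenModel

    # Hypothetical direct construction; cache_dir ends up at the Hugging Face
    # loaders used by the underlying Encoder/BiEncoder.
    span_model = SpanModel(config, encoder_from_pretrained=True, cache_dir="/data/hf_cache")
    token_model = TokenModel(config, encoder_from_pretrained=True, cache_dir="/data/hf_cache")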
gliner/modeling/encoder.py

Lines changed: 14 additions & 6 deletions

@@ -7,6 +7,7 @@

 from .layers import LayersFuser
 from ..utils import is_module_available, MissedPackageException
+from typing import Optional, Union

 IS_LLM2VEC = is_module_available('llm2vec')
 IS_PEFT = is_module_available('peft')
@@ -32,14 +33,21 @@
     from peft import LoraConfig, get_peft_model

 class Transformer(nn.Module):
-    def __init__(self, model_name, config, from_pretrained=False, labels_encoder = False):
+    def __init__(
+            self,
+            model_name,
+            config,
+            from_pretrained=False,
+            labels_encoder = False,
+            cache_dir:Optional[Union[str, Path]] = None
+    ):
         super().__init__()
         if labels_encoder:
             encoder_config = config.labels_encoder_config
         else:
             encoder_config = config.encoder_config
         if encoder_config is None:
-            encoder_config = AutoConfig.from_pretrained(model_name)
+            encoder_config = AutoConfig.from_pretrained(model_name, cache_dir=cache_dir)
         if config.vocab_size!=-1:
             encoder_config.vocab_size = config.vocab_size

@@ -107,11 +115,11 @@ def forward(self, *args, **kwargs):
         return encoder_layer

 class Encoder(nn.Module):
-    def __init__(self, config, from_pretrained: bool = False):
+    def __init__(self, config, from_pretrained: bool = False, cache_dir: Optional[Union[str, Path]]= None):
         super().__init__()

         self.bert_layer = Transformer( #transformer_model
-            config.model_name, config, from_pretrained,
+            config.model_name, config, from_pretrained, cache_dir = cache_dir
         )

         bert_hidden_size = self.bert_layer.model.config.hidden_size
@@ -137,11 +145,11 @@ def forward(self, *args, **kwargs) -> torch.Tensor:
         return token_embeddings

 class BiEncoder(Encoder):
-    def __init__(self, config, from_pretrained: bool = False):
+    def __init__(self, config, from_pretrained: bool = False, cache_dir:Optional[Union[str, Path]] = None):
         super().__init__(config, from_pretrained)
         if config.labels_encoder is not None:
             self.labels_encoder = Transformer( #transformer_model
-                config.labels_encoder, config, from_pretrained, True
+                config.labels_encoder, config, from_pretrained, True, cache_dir=cache_dir
             )
             le_hidden_size = self.labels_encoder.model.config.hidden_size

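For reference, cache_dir ultimately reaches the standard transformers loaders, which already support it; a small standalone sketch with an example backbone name:

    from transformers import AutoConfig, AutoTokenizer

    cache_dir = "/data/hf_cache"  # any writable directory

    # The Auto* loaders honor cache_dir, so config and tokenizer files are
    # downloaded under it rather than the default Hugging Face cache.
    encoder_config = AutoConfig.from_pretrained("microsoft/deberta-v3-small", cache_dir=cache_dir)
    tokenizer = AutoTokenizer.from_pretrained("microsoft/deberta-v3-small", cache_dir=cache_dir)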