gliner/model.py (17 changes: 9 additions & 8 deletions)
@@ -34,6 +34,7 @@ def __init__(
         words_splitter: Optional[Union[str, WordsSplitter]] = None,
         data_processor: Optional[Union[SpanProcessor, TokenProcessor]] = None,
         encoder_from_pretrained: bool = True,
+        cache_dir: Optional[Union[str, Path]] = None,
     ):
         """
         Initialize the GLiNER model.
@@ -50,19 +51,19 @@ def __init__(
         self.config = config
 
         if tokenizer is None and data_processor is None:
-            tokenizer = AutoTokenizer.from_pretrained(config.model_name)
+            tokenizer = AutoTokenizer.from_pretrained(config.model_name, cache_dir=cache_dir)
 
         if words_splitter is None and data_processor is None:
             words_splitter = WordsSplitter(config.words_splitter_type)
 
         if config.span_mode == "token_level":
             if model is None:
-                self.model = TokenModel(config, encoder_from_pretrained)
+                self.model = TokenModel(config, encoder_from_pretrained, cache_dir=cache_dir)
             else:
                 self.model = model
             if data_processor is None:
                 if config.labels_encoder is not None:
-                    labels_tokenizer = AutoTokenizer.from_pretrained(config.labels_encoder)
+                    labels_tokenizer = AutoTokenizer.from_pretrained(config.labels_encoder, cache_dir=cache_dir)
                     self.data_processor = TokenBiEncoderProcessor(config, tokenizer, words_splitter, labels_tokenizer)
                 else:
                     self.data_processor = TokenProcessor(config, tokenizer, words_splitter)
@@ -72,12 +73,12 @@ def __init__(
             self.decoder = TokenDecoder(config)
         else:
             if model is None:
-                self.model = SpanModel(config, encoder_from_pretrained)
+                self.model = SpanModel(config, encoder_from_pretrained, cache_dir=cache_dir)
             else:
                 self.model = model
             if data_processor is None:
                 if config.labels_encoder is not None:
-                    labels_tokenizer = AutoTokenizer.from_pretrained(config.labels_encoder)
+                    labels_tokenizer = AutoTokenizer.from_pretrained(config.labels_encoder, cache_dir=cache_dir)
                     self.data_processor = SpanBiEncoderProcessor(config, tokenizer, words_splitter, labels_tokenizer)
                 else:
                     self.data_processor = SpanProcessor(config, tokenizer, words_splitter)
@@ -778,10 +779,10 @@ def _from_pretrained(
         config_file = Path(model_dir) / "gliner_config.json"
 
         if load_tokenizer:
-            tokenizer = AutoTokenizer.from_pretrained(model_dir)
+            tokenizer = AutoTokenizer.from_pretrained(model_dir, cache_dir=cache_dir)
         else:
             if os.path.exists(os.path.join(model_dir, "tokenizer_config.json")):
-                tokenizer = AutoTokenizer.from_pretrained(model_dir)
+                tokenizer = AutoTokenizer.from_pretrained(model_dir, cache_dir=cache_dir)
             else:
                 tokenizer = None
         with open(config_file, "r") as f:
@@ -801,7 +802,7 @@ def _from_pretrained(
         add_tokens = ["[FLERT]", config.ent_token, config.sep_token]
 
         if not load_onnx_model:
-            gliner = cls(config, tokenizer=tokenizer, encoder_from_pretrained=False)
+            gliner = cls(config, tokenizer=tokenizer, encoder_from_pretrained=False, cache_dir=cache_dir)
             # to be able to load GLiNER models from previous version
             if (
                 config.class_token_index == -1 or config.vocab_size == -1
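The net effect in gliner/model.py is that the `cache_dir` already received by `GLiNER.from_pretrained` (via the Hugging Face hub mixin and `_from_pretrained`) now also reaches the tokenizer and encoder downloads. A minimal usage sketch; the checkpoint id and cache path are placeholders, not part of this diff:

```python
# Sketch only: shows the forwarded cache_dir in action. The checkpoint id and
# cache path below are examples, not something this PR pins down.
from gliner import GLiNER

model = GLiNER.from_pretrained(
    "urchade/gliner_small-v2.1",   # example GLiNER checkpoint
    cache_dir="/data/hf_cache",    # now forwarded to AutoTokenizer and the encoder
)

entities = model.predict_entities(
    "Bill Gates founded Microsoft in Albuquerque.",
    labels=["person", "organization", "location"],
)
print(entities)
```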
gliner/modeling/base.py (25 changes: 13 additions & 12 deletions)
@@ -1,4 +1,5 @@
-from typing import Optional, Tuple
+from pathlib import Path
+from typing import Optional, Tuple, Union
 from abc import ABC, abstractmethod
 from dataclasses import dataclass
 import warnings
@@ -80,14 +81,14 @@ def extract_prompt_features_and_word_embeddings(config, token_embeds, input_ids,
 
 
 class BaseModel(ABC, nn.Module):
-    def __init__(self, config, from_pretrained=False):
+    def __init__(self, config, from_pretrained = False, cache_dir: Optional[Union[str, Path]] = None):
         super(BaseModel, self).__init__()
         self.config = config
 
         if not config.labels_encoder:
-            self.token_rep_layer = Encoder(config, from_pretrained)
+            self.token_rep_layer = Encoder(config, from_pretrained, cache_dir = cache_dir)
         else:
-            self.token_rep_layer = BiEncoder(config, from_pretrained)
+            self.token_rep_layer = BiEncoder(config, from_pretrained, cache_dir=cache_dir)
         if self.config.has_rnn:
             self.rnn = LstmSeq2SeqEncoder(config)
 
@@ -240,12 +241,12 @@ def loss(self, x):
 
 
 class SpanModel(BaseModel):
-    def __init__(self, config, encoder_from_pretrained):
-        super(SpanModel, self).__init__(config, encoder_from_pretrained)
-        self.span_rep_layer = SpanRepLayer(span_mode=config.span_mode,
-                                           hidden_size=config.hidden_size,
-                                           max_width=config.max_width,
-                                           dropout=config.dropout)
+    def __init__(self, config, encoder_from_pretrained, cache_dir: Optional[Union[str, Path]] = None):
+        super(SpanModel, self).__init__(config, encoder_from_pretrained, cache_dir = cache_dir)
+        self.span_rep_layer = SpanRepLayer(span_mode = config.span_mode,
+                                           hidden_size = config.hidden_size,
+                                           max_width = config.max_width,
+                                           dropout = config.dropout)
 
         self.prompt_rep_layer = create_projection_layer(config.hidden_size, config.dropout)
 
@@ -331,8 +332,8 @@ def loss(self, scores, labels, prompts_embedding_mask, mask_label,
 
 
 class TokenModel(BaseModel):
-    def __init__(self, config, encoder_from_pretrained):
-        super(TokenModel, self).__init__(config, encoder_from_pretrained)
+    def __init__(self, config, encoder_from_pretrained, cache_dir:Optional[Union[str, Path]] = None):
+        super(TokenModel, self).__init__(config, encoder_from_pretrained, cache_dir=cache_dir)
         self.scorer = Scorer(config.hidden_size, config.dropout)
 
     def forward(self,
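For the modeling layer, the same argument now threads from `SpanModel`/`TokenModel` through `BaseModel` into `Encoder`/`BiEncoder`. A sketch of direct construction, assuming a default-style `GLiNERConfig`; the backbone name and cache path are illustrative, not part of the diff:

```python
# Illustrative only: the new cache_dir parameter on the modeling classes.
# The GLiNERConfig defaults and backbone name are assumptions for this sketch.
from gliner.config import GLiNERConfig
from gliner.modeling.base import SpanModel, TokenModel

config = GLiNERConfig(model_name="microsoft/deberta-v3-small")  # example backbone

# Forwards cache_dir to BaseModel, which hands it to Encoder/BiEncoder.
span_model = SpanModel(config, encoder_from_pretrained=True, cache_dir="/data/hf_cache")
# token_model = TokenModel(config, encoder_from_pretrained=True, cache_dir="/data/hf_cache")
```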
gliner/modeling/encoder.py (20 changes: 14 additions & 6 deletions)
@@ -7,6 +7,7 @@
 
 from .layers import LayersFuser
 from ..utils import is_module_available, MissedPackageException
+from typing import Optional, Union
 
 IS_LLM2VEC = is_module_available('llm2vec')
 IS_PEFT = is_module_available('peft')
@@ -32,14 +33,21 @@
     from peft import LoraConfig, get_peft_model
 
 class Transformer(nn.Module):
-    def __init__(self, model_name, config, from_pretrained=False, labels_encoder = False):
+    def __init__(
+        self,
+        model_name,
+        config,
+        from_pretrained=False,
+        labels_encoder = False,
+        cache_dir:Optional[Union[str, Path]] = None
+    ):
         super().__init__()
         if labels_encoder:
             encoder_config = config.labels_encoder_config
         else:
             encoder_config = config.encoder_config
         if encoder_config is None:
-            encoder_config = AutoConfig.from_pretrained(model_name)
+            encoder_config = AutoConfig.from_pretrained(model_name, cache_dir=cache_dir)
             if config.vocab_size!=-1:
                 encoder_config.vocab_size = config.vocab_size
 
@@ -107,11 +115,11 @@ def forward(self, *args, **kwargs):
         return encoder_layer
 
 class Encoder(nn.Module):
-    def __init__(self, config, from_pretrained: bool = False):
+    def __init__(self, config, from_pretrained: bool = False, cache_dir: Optional[Union[str, Path]]= None):
         super().__init__()
 
         self.bert_layer = Transformer( #transformer_model
-            config.model_name, config, from_pretrained,
+            config.model_name, config, from_pretrained, cache_dir = cache_dir
         )
 
         bert_hidden_size = self.bert_layer.model.config.hidden_size
@@ -137,11 +145,11 @@ def forward(self, *args, **kwargs) -> torch.Tensor:
         return token_embeddings
 
 class BiEncoder(Encoder):
-    def __init__(self, config, from_pretrained: bool = False):
+    def __init__(self, config, from_pretrained: bool = False, cache_dir:Optional[Union[str, Path]] = None):
         super().__init__(config, from_pretrained)
         if config.labels_encoder is not None:
             self.labels_encoder = Transformer( #transformer_model
-                config.labels_encoder, config, from_pretrained, True
+                config.labels_encoder, config, from_pretrained, True, cache_dir=cache_dir
             )
             le_hidden_size = self.labels_encoder.model.config.hidden_size
 
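At the bottom of the chain, `cache_dir` is simply handed to the `transformers` loaders, which treat it as the download and lookup directory instead of the default Hugging Face cache. A standalone illustration of that behavior, independent of GLiNER; the backbone name and path are examples:

```python
# Plain transformers behavior that the forwarded cache_dir relies on: files are
# downloaded to / resolved from cache_dir rather than ~/.cache/huggingface/hub.
from transformers import AutoConfig, AutoModel, AutoTokenizer

model_name = "microsoft/deberta-v3-small"   # example backbone
cache_dir = "/data/hf_cache"                # example cache location

cfg = AutoConfig.from_pretrained(model_name, cache_dir=cache_dir)
tok = AutoTokenizer.from_pretrained(model_name, cache_dir=cache_dir)
enc = AutoModel.from_pretrained(model_name, config=cfg, cache_dir=cache_dir)
```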