Merged

28 commits
36761e7
labels decoder first changes
Ingvarstep Jun 7, 2025
cad9680
prepare prompt-based labels generation
Ingvarstep Jun 7, 2025
a8dfb5e
add labels generation API
Ingvarstep Jun 7, 2025
3e5185a
add support of training span-level labels decoder
Ingvarstep Jun 8, 2025
7508b32
add minimal span inference
Ingvarstep Jun 8, 2025
372a224
align decoder input ids with span representations
Ingvarstep Jun 14, 2025
aeaaa28
make decoding with generation more clean
Ingvarstep Jun 14, 2025
6b6468d
update labels decoding
Ingvarstep Jul 15, 2025
ab2111b
update label decoder
Ingvarstep Jul 27, 2025
be94a81
add span tokens embeddings as a condition to generate labels
Ingvarstep Jul 29, 2025
f2326c7
update full decoder context version
Ingvarstep Jul 30, 2025
325c71a
limit transformers version
Ingvarstep Aug 5, 2025
ccef141
fix data processing and encoding
Ingvarstep Aug 6, 2025
fe48f9e
add generation with constraints
Ingvarstep Aug 8, 2025
c070de9
fix cases with empty predictions
Ingvarstep Aug 8, 2025
e0043d9
Merge pull request #283 from urchade/labels_trie
Ingvarstep Aug 8, 2025
2e7d9e1
remove labels_trie.cpp
Ingvarstep Aug 8, 2025
79d8386
fix base version training
Ingvarstep Aug 8, 2025
52460be
replace generation function to transformers
Ingvarstep Aug 9, 2025
c1fbcb3
fix generation with contraints
Ingvarstep Aug 10, 2025
b7cc3c2
fix base version processor
Ingvarstep Aug 12, 2025
829b6a9
minor fixes
Ingvarstep Aug 12, 2025
d8bd1b5
fix python version of labels trie
Ingvarstep Aug 13, 2025
4c2221c
fix decoding for token-level models
Ingvarstep Aug 13, 2025
da6f692
fix gen kwargs passing
Ingvarstep Aug 14, 2025
9dd1f79
update trainer to be compatible with newer versions of transformers
Ingvarstep Aug 14, 2025
be22206
update extended readme
Ingvarstep Aug 14, 2025
979c85e
Merge branch 'main' into labels_decoder
Ingvarstep Aug 15, 2025
94 changes: 94 additions & 0 deletions README_Extended.md
@@ -250,6 +250,100 @@ Bill Gates => person
Microsoft => organization
```

## GLiNER with Decoder

A new GLiNER architecture has recently been introduced. It first extracts spans from the text, then uses their latent representations to condition the generation of label names for each span. This approach enables new use cases, such as entity linking, and expands GLiNER's capabilities.

Because modern decoders are typically large models trained on extensive datasets, integrating one lets GLiNER tap into the broader knowledge such models carry.

```python
from gliner import GLiNER

model = GLiNER.from_pretrained('model-path')  # path or hub id of a decoder-equipped GLiNER model

texts = ["Apple was founded as Apple Computer Company on April 1, 1976, by Steve Wozniak, Steve Jobs (1955–2011) and Ronald Wayne to develop and sell Wozniak's Apple I personal computer."]

labels = ["person", "other"]

predictions = model.run(texts, labels, threshold=0.3, num_gen_sequences=1)
```

**Example output:**

```json
[
  [
    {
      "start": 21,
      "end": 26,
      "text": "Apple",
      "label": "other",
      "score": 0.6795641779899597,
      "generated labels": ["Organization"]
    },
    {
      "start": 47,
      "end": 60,
      "text": "April 1, 1976",
      "label": "other",
      "score": 0.44296327233314514,
      "generated labels": ["Date"]
    },
    {
      "start": 65,
      "end": 78,
      "text": "Steve Wozniak",
      "label": "person",
      "score": 0.9934439659118652,
      "generated labels": ["Person"]
    },
    {
      "start": 80,
      "end": 90,
      "text": "Steve Jobs",
      "label": "person",
      "score": 0.9725918769836426,
      "generated labels": ["Person"]
    },
    {
      "start": 107,
      "end": 119,
      "text": "Ronald Wayne",
      "label": "person",
      "score": 0.9964536428451538,
      "generated labels": ["Person"]
    }
  ]
]
```

---

You can also restrict the decoder to generate labels only from a predefined set:

```python
model.run(
texts, labels,
threshold=0.3,
num_gen_sequences=1,
gen_constraints=[
"organization type", "city", "organization",
"technology", "date", "person"
]
)
```

---

Two label-trie implementations are available: a pure-Python version and a faster, more memory-efficient C++ version. To enable the C++ version, install **Cython**:

```bash
pip install cython
```

The C++ trie significantly improves both speed and memory usage when constraining generation over millions of labels.
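
To illustrate the idea, here is a minimal pure-Python sketch of a label trie used for constrained generation. This is not the repository's implementation; the class and method names are hypothetical, and the real version also handles tokenizer details:

```python
# Hypothetical sketch of a label trie for constrained decoding;
# names (LabelTrie, allowed_next) are illustrative, not the library's API.
class LabelTrie:
    def __init__(self, label_token_ids, eos_token_id):
        # Each label is a sequence of token ids; store them in nested dicts.
        self.root = {}
        for ids in label_token_ids:
            node = self.root
            for tok in ids:
                node = node.setdefault(tok, {})
            node[eos_token_id] = {}  # a complete label may end here

    def allowed_next(self, prefix):
        # Return the token ids that may legally follow the generated prefix.
        node = self.root
        for tok in prefix:
            if tok not in node:
                return []
            node = node[tok]
        return list(node.keys())
```

Such a trie plugs naturally into the `generate` API of `transformers` via `prefix_allowed_tokens_fn`, which receives the tokens generated so far and returns the ids allowed at the next step.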

## Using FlashDeBERTa

Most GLiNER models use the DeBERTa encoder as their backbone. This architecture offers strong token classification performance and typically requires less data to achieve good results. However, a major drawback has been its slower inference speed, and until recently, there was no flash attention implementation compatible with DeBERTa's disentangled attention mechanism.
2 changes: 1 addition & 1 deletion configs/config.yaml
@@ -52,4 +52,4 @@ shuffle_types: true
random_drop: true
max_neg_type_ratio: 1
max_len: 512
-freeze_token_rep: false
\ No newline at end of file
+freeze_token_rep: false
56 changes: 56 additions & 0 deletions configs/config_decoder.yaml
@@ -0,0 +1,56 @@
# Model Configuration
model_name: microsoft/deberta-v3-small # Hugging Face model
labels_encoder: null
labels_decoder: HuggingFaceTB/SmolLM2-135M-Instruct #"openai-community/gpt2"
name: "span level gliner"
max_width: 12
hidden_size: 512
dropout: 0.3
fine_tune: true
subtoken_pooling: first
fuse_layers: false
post_fusion_schema: null
decoder_mode: "span"
full_decoder_context: true
span_mode: markerV1

# Training Parameters
num_steps: 300000
train_batch_size: 4
eval_every: 300
warmup_ratio: 0.05
scheduler_type: "cosine"

# loss function
loss_alpha: 0.75
loss_gamma: 0
label_smoothing: 0
loss_reduction: "mean"

# Learning Rate and weight decay Configuration
lr_encoder: 2e-5
lr_others: 3e-5
weight_decay_encoder: 0.01
weight_decay_other: 0.01

max_grad_norm: 10.0

# Directory Paths
root_dir: gliner_logs
train_data: "data/nuner_train.json" # see https://github.com/urchade/GLiNER/tree/main/data
val_data_dir: "none"
# "NER_datasets": val data from the paper can be obtained from "https://drive.google.com/file/d/1T-5IbocGka35I7X3CE6yKe5N_Xg2lVKT/view"

# Pretrained Model Path
prev_path: null

save_total_limit: 3 # maximum number of checkpoints to keep

# Advanced Training Settings
size_sup: -1
max_types: 30
shuffle_types: true
random_drop: true
max_neg_type_ratio: 1
max_len: 512
freeze_token_rep: false
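
Assuming the repository's standard training entry point accepts a config path (as with the other configs in this directory), training with the decoder config would look something like:

```bash
python train.py --config configs/config_decoder.yaml
```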
14 changes: 14 additions & 0 deletions gliner/config.py
@@ -8,12 +8,15 @@ class GLiNERConfig(PretrainedConfig):
def __init__(self,
model_name: str = "microsoft/deberta-v3-small",
labels_encoder: str = None,
labels_decoder: str = None,
name: str = "span level gliner",
max_width: int = 12,
hidden_size: int = 512,
dropout: float = 0.4,
fine_tune: bool = True,
subtoken_pooling: str = "first",
decoder_mode: str = None, #prompt|span
full_decoder_context: bool = True,
span_mode: str = "markerV0",
post_fusion_schema: str = '', #l2l-l2t-t2t
num_post_fusion_layers: int = 1,
@@ -28,6 +31,7 @@ def __init__(self,
class_token_index: int = -1,
encoder_config: Optional[dict] = None,
labels_encoder_config: Optional[dict] = None,
labels_decoder_config: Optional[dict] = None,
ent_token = "<<ENT>>",
sep_token = "<<SEP>>",
_attn_implementation = None,
@@ -47,14 +51,24 @@ def __init__(self,
labels_encoder_config = CONFIG_MAPPING[labels_encoder_config["model_type"]](**labels_encoder_config)
self.labels_encoder_config = labels_encoder_config

if isinstance(labels_decoder_config, dict):
labels_decoder_config["model_type"] = (labels_decoder_config["model_type"]
if "model_type" in labels_decoder_config
else "gpt")
labels_decoder_config = CONFIG_MAPPING[labels_decoder_config["model_type"]](**labels_decoder_config)
self.labels_decoder_config = labels_decoder_config

self.model_name = model_name
self.labels_encoder = labels_encoder
self.labels_decoder = labels_decoder
self.name = name
self.max_width = max_width
self.hidden_size = hidden_size
self.dropout = dropout
self.fine_tune = fine_tune
self.subtoken_pooling = subtoken_pooling
self.decoder_mode = decoder_mode
self.full_decoder_context = full_decoder_context
self.span_mode = span_mode
self.post_fusion_schema = post_fusion_schema
self.num_post_fusion_layers = num_post_fusion_layers
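
For reference, the new fields can be passed directly when constructing a config. A minimal sketch, assuming a build of this branch; all values mirror configs/config_decoder.yaml above, and the decoder checkpoint is only an example:

```python
from gliner.config import GLiNERConfig

# Values taken from configs/config_decoder.yaml; adjust to your setup.
config = GLiNERConfig(
    model_name="microsoft/deberta-v3-small",
    labels_decoder="HuggingFaceTB/SmolLM2-135M-Instruct",
    decoder_mode="span",           # "prompt" or "span", per the comment in __init__
    full_decoder_context=True,
    span_mode="markerV1",
)
```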