109 changes: 91 additions & 18 deletions anonipy/anonymize/generators/llm_label_generator.py
@@ -1,14 +1,13 @@
import re
import warnings
from typing import Tuple, List

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

from ...utils.package import is_installed_with
from .interface import GeneratorInterface
from ...definitions import Entity


# =====================================
# Main class
# =====================================
@@ -36,13 +35,15 @@ def __init__(
*args,
model_name: str = "HuggingFaceTB/SmolLM2-1.7B-Instruct",
use_gpu: bool = False,
use_quant: bool = False,
**kwargs,
):
"""Initializes the LLM label generator.

Args:
model_name: The name of the model to use.
use_gpu: Whether to use GPU or not.
use_quant: Whether to use quantization or not.

Examples:
>>> from anonipy.anonymize.generators import LLMLabelGenerator
@@ -59,8 +60,14 @@ def __init__(
)
use_gpu = False

if use_quant and not is_installed_with(["quant", "all"]):
warnings.warn(
"The use_quant=True flag requires the 'quant' extra dependencies, but they are not installed. Setting use_quant=False."
)
use_quant = False

self.model, self.tokenizer = self._prepare_model_and_tokenizer(
model_name, use_gpu
model_name, use_gpu, use_quant
)

def generate(
@@ -108,7 +115,7 @@ def generate(
# =================================

def _prepare_model_and_tokenizer(
self, model_name: str, use_gpu: bool
self, model_name: str, use_gpu: bool, use_quant: bool
) -> Tuple[AutoModelForCausalLM, AutoTokenizer]:
"""Prepares the model and tokenizer.

@@ -125,12 +132,66 @@ def _prepare_model_and_tokenizer(
device = torch.device(
"cuda" if use_gpu and torch.cuda.is_available() else "cpu"
)
model = AutoModelForCausalLM.from_pretrained(model_name).to(device)
# prepare the tokenizer
tokenizer = AutoTokenizer.from_pretrained(
dtype = torch.float32 if device.type == "cpu" else torch.float16

model = self._load_model(model_name, device, dtype, use_quant, use_gpu)
tokenizer = self._load_tokenizer(model_name)

return model, tokenizer

def _load_model(
self,
model_name: str,
device: torch.device,
dtype: torch.dtype,
use_quant: bool,
use_gpu: bool,
) -> AutoModelForCausalLM:
"""Load the model with appropriate configuration.

Args:
model_name: The name of the model to use.
device: The device to use for the model.
dtype: The data type to use for the model.
use_quant: Whether to use quantization or not.
use_gpu: Whether to use GPU or not.

Returns:
The huggingface model.

"""
if use_quant and use_gpu:
quant_config = BitsAndBytesConfig(
load_in_8bit=True, bnb_4bit_compute_dtype=dtype
)
return AutoModelForCausalLM.from_pretrained(
model_name,
device_map=device,
torch_dtype=dtype,
quantization_config=quant_config,
)

if use_quant:
warnings.warn(
"Quantization is only supported on GPU, but use_gpu=False. Loading model without quantization."
)

return AutoModelForCausalLM.from_pretrained(
model_name, device_map=device, torch_dtype=dtype
)

def _load_tokenizer(self, model_name: str) -> AutoTokenizer:
"""Load the tokenizer with appropriate configuration.

Args:
model_name: The name of the model to use.

Returns:
The huggingface tokenizer.
"""
return AutoTokenizer.from_pretrained(
model_name, padding_side="right", use_fast=False
)
return model, tokenizer

def _generate_response(
self, message: List[dict], temperature: float, top_p: float
@@ -152,15 +213,27 @@ def _generate_response(
message, tokenize=True, return_tensors="pt", add_generation_prompt=True
).to(self.model.device)

# generate the response
with torch.no_grad():
output_ids = self.model.generate(
input_ids,
max_new_tokens=50,
temperature=temperature,
top_p=top_p,
do_sample=True,
)
# create attention mask (1 for all tokens)
attention_mask = torch.ones_like(input_ids)

# set pad token id if not set
if self.tokenizer.pad_token_id is None:
self.tokenizer.pad_token_id = self.tokenizer.eos_token_id

with warnings.catch_warnings():
warnings.filterwarnings("ignore", message=".*the `logits` model output.*")

# generate the response
with torch.no_grad():
output_ids = self.model.generate(
input_ids,
attention_mask=attention_mask,
max_new_tokens=50,
temperature=temperature,
top_p=top_p,
do_sample=True,
pad_token_id=self.tokenizer.pad_token_id,
)

# decode the response
response = self.tokenizer.decode(
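Taken together, the new flags compose as follows. A minimal usage sketch (not part of the diff; the fallback behaviour is described by the warnings added above):

```python
from anonipy.anonymize.generators import LLMLabelGenerator

# default: CPU, float32, no quantization
generator = LLMLabelGenerator()

# 8-bit quantized inference; needs use_gpu=True, a CUDA device, and the
# 'quant' extra (bitsandbytes) installed, otherwise the generator warns
# and falls back to an unquantized model
generator = LLMLabelGenerator(use_gpu=True, use_quant=True)
```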
34 changes: 34 additions & 0 deletions anonipy/utils/package.py
@@ -0,0 +1,34 @@
from importlib.metadata import requires
from typing import Union


def is_installed_with(extra: Union[str, list[str]]) -> bool:
"""Check if anonipy was installed with specific optional dependencies.

Args:
extra: The optional dependency or list of dependencies to check.
Valid values are: 'dev', 'test', 'quant', 'all'

Returns:
True if package was installed with the specified optional dependencies,
False otherwise.

Example:
>>> from anonipy.utils.package import is_installed_with
>>> is_installed_with('dev') # check if dev dependencies are installed
>>> is_installed_with(['dev', 'test']) # check multiple dependency groups
"""
if isinstance(extra, str):
extra = [extra]

try:
package_requires = requires("anonipy") or []
installed_extras = set()

for req in package_requires:
if "extra == " in req:
installed_extras.add(req.split("extra == ")[1].strip("\"'"))

return any(e in installed_extras for e in extra)
except Exception:
return False
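For context on the parsing above: extras appear in a package's requirement metadata as PEP 508 markers of the form `; extra == "name"`. A small illustrative sketch of the same extraction logic (the sample strings are hypothetical, not anonipy's actual metadata, which `requires("anonipy")` would return):

```python
# hypothetical requirement strings, shaped like importlib.metadata.requires() output
sample = [
    "torch>=2.0",
    'bitsandbytes; extra == "quant"',
    'pytest; extra == "test"',
]

extras = set()
for req in sample:
    if "extra == " in req:
        # 'bitsandbytes; extra == "quant"' -> 'quant'
        extras.add(req.split("extra == ")[1].strip("\"'"))

print(extras)  # {'quant', 'test'}
```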
12 changes: 10 additions & 2 deletions docs/how-to-guides/posts/generators-overview.md
@@ -83,7 +83,15 @@ The [LLMLabelGenerator][anonipy.anonymize.generators.LLMLabelGenerator] is a one
from anonipy.anonymize.generators import LLMLabelGenerator
```

The `LLMLabelGenerator` currently does not require any input parameters at initialization.
The `LLMLabelGenerator` accepts the following input parameters at initialization:

::: anonipy.anonymize.generators.LLMLabelGenerator.__init__
options:
show_root_heading: False
show_docstring_description: False
show_docstring_examples: False
show_docstring_returns: False
show_source: False

Let us now initialize the LLM label generator.

@@ -92,7 +100,7 @@ llm_generator = LLMLabelGenerator()
```

!!! info "Initialization warnings"
The initialization of `LLMLabelGenerator` will throw some warnings. Ignore them. These are expected due to the use of package dependencies.
The initialization of `LLMLabelGenerator` will throw some warnings. Ignore them. These are expected due to the use of package dependencies.

To use the generator, we can call the `generate` method. The `generate` method receives the following parameters:

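A rough sketch of the call site (not part of this diff; the `generate` parameters are collapsed above, so `temperature` and `top_p` are assumed from `_generate_response` in the module diff, and the `Entity` fields are hypothetical, mirroring the test fixtures):

```python
from anonipy.definitions import Entity
from anonipy.anonymize.generators import LLMLabelGenerator

# hypothetical entity; field names follow the Entity usages in the tests
entity = Entity(
    text="John Doe",
    label="name",
    start_index=0,
    end_index=8,
    type="string",
)

llm_generator = LLMLabelGenerator()
# sampling parameters assumed from _generate_response in the module diff
label = llm_generator.generate(entity, temperature=0.7, top_p=0.95)
```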
5 changes: 4 additions & 1 deletion pyproject.toml
@@ -42,7 +42,10 @@ test = [
"pytest",
"pytest-cov",
]
all = ["anonipy[dev,test]"]
quant = [
"bitsandbytes",
]
all = ["anonipy[dev,test,quant]"]

[tool.setuptools.packages.find]
where = ["."]
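With this change, the quantization dependencies install via the new extra, e.g. `pip install "anonipy[quant]"`; the `all` extra now also pulls in `bitsandbytes`.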
2 changes: 1 addition & 1 deletion requirements.txt
@@ -1,6 +1,6 @@
# NLP and LLMs
spacy==3.8.2
gliner==0.2.13
gliner==0.2.16
gliner-spacy==0.0.10
transformers==4.45.2
accelerate>=0.26.0
62 changes: 38 additions & 24 deletions test/test_extractors.py
@@ -171,7 +171,7 @@
start_index=86,
end_index=96,
type="date",
regex=r"Date of Examination: (.*)"
regex=r"Date of Examination: (.*)",
),
# Repeated entity
Entity(
@@ -180,7 +180,7 @@
start_index=759,
end_index=769,
type="date",
regex=r"Date of Examination: (.*)"
regex=r"Date of Examination: (.*)",
),
]
TEST_MULTI_REPEATS = [
@@ -224,6 +224,7 @@
),
]


@pytest.fixture(autouse=True)
def suppress_warnings():
warnings.filterwarnings("ignore", category=UserWarning)
@@ -270,7 +271,7 @@ def pattern_extractor():
{"SHAPE": "dddd"},
]
],
}
},
]
return PatternExtractor(labels=labels, lang=LANGUAGES.ENGLISH)

@@ -412,6 +413,7 @@ def test_pattern_extractor_extract_default(pattern_extractor):
assert p_entity.regex == t_entity.regex
assert p_entity.score == 1.0


def test_pattern_extractor_detect_repeats_false():
extractor = PatternExtractor(
labels=[
@@ -434,6 +436,7 @@ def test_pattern_extractor_detect_repeats_false():
assert excepted_entity.regex == entities[0].regex
assert excepted_entity.score >= 0.5


def test_pattern_extractor_detect_repeats_true():
extractor = PatternExtractor(
labels=[
@@ -455,6 +458,7 @@ def test_pattern_extractor_detect_repeats_true():
assert p_entity.regex == t_entity.regex
assert p_entity.score >= 0.5


def test_multi_extractor_init():
with pytest.raises(TypeError):
MultiExtractor()
@@ -568,16 +572,21 @@ def test_multi_extractor_extract_single_extractor_pattern(multi_extractor):

def test_multi_extractor_detect_repeats_false():
extractors = [
NERExtractor(labels=[
{"label": "name", "type": "string"},
]),
PatternExtractor(labels=[
{
"label": "date",
"type": "date",
"regex": r"Date of Examination: (.*)",
},
])]
NERExtractor(
labels=[
{"label": "name", "type": "string"},
]
),
PatternExtractor(
labels=[
{
"label": "date",
"type": "date",
"regex": r"Date of Examination: (.*)",
},
]
),
]
extractor = MultiExtractor(extractors)
_, joint_entities = extractor(TEST_ORIGINAL_TEXT, detect_repeats=False)
for p_entity, t_entity in zip(joint_entities, TEST_MULTI_REPEATS[:3]):
@@ -592,16 +601,21 @@ def test_multi_extractor_detect_repeats_false():

def test_multi_extractor_detect_repeats_true():
extractors = [
NERExtractor(labels=[
{"label": "name", "type": "string"},
]),
PatternExtractor(labels=[
{
"label": "date",
"type": "date",
"regex": r"Date of Examination: (.*)",
},
])]
NERExtractor(
labels=[
{"label": "name", "type": "string"},
]
),
PatternExtractor(
labels=[
{
"label": "date",
"type": "date",
"regex": r"Date of Examination: (.*)",
},
]
),
]
extractor = MultiExtractor(extractors)
_, joint_entities = extractor(TEST_ORIGINAL_TEXT, detect_repeats=True)
for p_entity, t_entity in zip(joint_entities, TEST_MULTI_REPEATS):
@@ -611,4 +625,4 @@ def test_multi_extractor_detect_repeats_true():
assert p_entity.end_index == t_entity.end_index
assert p_entity.type == t_entity.type
assert p_entity.regex == t_entity.regex
assert p_entity.score >= 0.5
assert p_entity.score >= 0.5