Merged (changes from 13 commits)
68 changes: 60 additions & 8 deletions src/transformers/pipelines.py
@@ -1329,6 +1329,7 @@ def __init__(
ignore_labels=["O"],
task: str = "",
grouped_entities: bool = False,
+        ignore_subwords: bool = True,
):
super().__init__(
model=model,
@@ -1350,6 +1351,7 @@ def __init__(
self._basic_tokenizer = BasicTokenizer(do_lower_case=False)
self.ignore_labels = ignore_labels
self.grouped_entities = grouped_entities
+        self.ignore_subwords = ignore_subwords

def __call__(self, *args, **kwargs):
"""
@@ -1372,7 +1374,10 @@ def __call__(self, *args, **kwargs):
"""
inputs = self._args_parser(*args, **kwargs)
answers = []
-        for sentence in inputs:
+
+        for i, sentence in enumerate(inputs):
+            if "offset_mapping" in kwargs:
+                offset_mapping = kwargs["offset_mapping"][i]

# Manage correct placement of the tensors
with self.device_placement():
@@ -1382,7 +1387,14 @@
return_attention_mask=False,
return_tensors=self.framework,
truncation=True,
+                    return_special_tokens_mask=True,
+                    return_offsets_mapping=self.tokenizer.is_fast,
)
if "offset_mapping" in tokens:
offset_mapping = tokens["offset_mapping"].cpu().numpy()[0]
del tokens["offset_mapping"]
special_tokens_mask = tokens["special_tokens_mask"].cpu().numpy()[0]
del tokens["special_tokens_mask"]

# Forward
if self.framework == "tf":
@@ -1399,16 +1411,26 @@

entities = []
# Filter to labels not in `self.ignore_labels`
+            # Filter special_tokens
filtered_labels_idx = [
(idx, label_idx)
for idx, label_idx in enumerate(labels_idx)
-                if self.model.config.id2label[label_idx] not in self.ignore_labels
+                if (self.model.config.id2label[label_idx] not in self.ignore_labels) and not special_tokens_mask[idx]
]

            for idx, label_idx in filtered_labels_idx:
+
+                if int(input_ids[idx]) == self.tokenizer.unk_token_id:
+                    if offset_mapping is not None:
+                        start_ind, end_ind = offset_mapping[idx]
+                        word = sentence[start_ind:end_ind]
+                    else:
+                        raise Exception("Use a fast tokenizer or provide offset_mapping parameter")
+                else:
+                    word = self.tokenizer.convert_ids_to_tokens([int(input_ids[idx])])[0]
+
entity = {
"word": self.tokenizer.convert_ids_to_tokens(int(input_ids[idx])),
"word": word,
"score": score[idx][label_idx].item(),
"entity": self.model.config.id2label[label_idx],
"index": idx,
@@ -1435,17 +1457,31 @@ def group_sub_entities(self, entities: List[dict]) -> dict:
entities (:obj:`dict`): The entities predicted by the pipeline.
"""
# Get the first entity in the entity group
entity = entities[0]["entity"]
scores = np.mean([entity["score"] for entity in entities])
entity = entities[0]["entity"].split("-")[-1]
scores = np.nanmean([entity["score"] for entity in entities])
tokens = [entity["word"] for entity in entities]

entity_group = {
"entity_group": entity,
"score": np.mean(scores),
"word": self.tokenizer.convert_tokens_to_string(tokens),
"word": self.convert_tokens_to_string(tokens),
}
return entity_group

+    def is_subword_fn(self, token: str) -> bool:
+        if token.startswith("##"):
+            return True
+        return False
+
+    def convert_tokens_to_string(self, tokens):
+        """ Converts a sequence of tokens (string) into a single string. """
+        if hasattr(self.tokenizer, "convert_tokens_to_string"):
+            # fast tokenizers may not implement convert_tokens_to_string
+            return self.tokenizer.convert_tokens_to_string(tokens)
+        else:
+            out_string = " ".join(tokens).replace(" ##", "").strip()
+            return out_string
+
def group_entities(self, entities: List[dict]) -> List[dict]:
"""
Find and group together the adjacent tokens with the same entity predicted.
@@ -1457,11 +1493,18 @@ def group_entities(self, entities: List[dict]) -> List[dict]:
entity_groups = []
entity_group_disagg = []

+        if hasattr(self.tokenizer, "is_subword_fn"):
+            is_subword_fn = self.tokenizer.is_subword_fn
+        else:
+            is_subword_fn = self.is_subword_fn
+
if entities:
last_idx = entities[-1]["index"]

        for entity in entities:
+
            is_last_idx = entity["index"] == last_idx
+            is_subword = self.ignore_subwords and is_subword_fn(entity["word"])
if not entity_group_disagg:
entity_group_disagg += [entity]
if is_last_idx:
@@ … @@

# If the current entity is similar and adjacent to the previous entity, append it to the disaggregated entity group
            # The split is meant to account for the "B" and "I" prefixes
+            # Shouldn't merge if both entities are B-type
if (
entity["entity"].split("-")[-1] == entity_group_disagg[-1]["entity"].split("-")[-1]
(
entity["entity"].split("-")[-1] == entity_group_disagg[-1]["entity"].split("-")[-1]
and entity["entity"].split("-")[0] != "B"
)
and entity["index"] == entity_group_disagg[-1]["index"] + 1
):
) or is_subword:
# Modify subword type to be previous_type
if is_subword:
entity["entity"] = entity_group_disagg[-1]["entity"].split("-")[-1]
entity["score"] = np.nan # set ignored scores to nan and use np.nanmean

entity_group_disagg += [entity]
# Group the entities at the last entity
if is_last_idx:
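For orientation, a minimal usage sketch of the option this diff adds. This is not part of the PR; the checkpoint name, example sentence, and printed fields are illustrative assumptions.

from transformers import pipeline

# Hypothetical usage of the ignore_subwords flag introduced by this PR.
# The checkpoint name is an assumption; any NER-finetuned model should work.
nlp = pipeline(
    "ner",
    model="dbmdz/bert-large-cased-finetuned-conll03-english",
    grouped_entities=True,
    ignore_subwords=True,  # fold "##"-prefixed subword tokens into the preceding entity
)

for group in nlp("Consuelo Araújo Noguera met Andrés Pastrana."):
    # Each group carries "entity_group" (B-/I- prefix stripped), "score", and "word".
    print(group["entity_group"], group["word"])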
63 changes: 54 additions & 9 deletions tests/test_pipelines.py
@@ -736,15 +736,29 @@ def _test_ner_pipeline(
{"entity": "I-ORG", "index": 7, "score": 0.9986497163772583, "word": "UN"},
],
        ]
+
expected_grouped_ner_results = [
[
{"entity_group": "B-PER", "score": 0.9710702640669686, "word": "Consuelo Araújo Noguera"},
{"entity_group": "B-PER", "score": 0.9997273534536362, "word": "Andrés Pastrana"},
{"entity_group": "B-ORG", "score": 0.8589080572128296, "word": "Farc"},
{"entity_group": "PER", "score": 0.999369223912557, "word": "Consuelo Araújo Noguera"},
{"entity_group": "PER", "score": 0.9997771680355072, "word": "Andrés Pastrana"},
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A test for hyphenated names (ex. Juantia Gomez-Cortez) would be useful, especially given that the fast and slow tokenizers have different codepaths for reconstructing the original text. I had to implement grouping of named entities myself recently and was tripped up by that corner case.

{"entity_group": "ORG", "score": 0.9989739060401917, "word": "Farc"},
],
[
{"entity_group": "PER", "score": 0.9968166351318359, "word": "Enzo"},
{"entity_group": "ORG", "score": 0.9986497163772583, "word": "UN"},
],
]

expected_grouped_ner_results_w_subword = [
[
{"entity_group": "PER", "score": 0.9994944930076599, "word": "Cons"},
{"entity_group": "PER", "score": 0.9663328925768534, "word": "##uelo Araújo Noguera"},
{"entity_group": "PER", "score": 0.9997273534536362, "word": "Andrés Pastrana"},
{"entity_group": "ORG", "score": 0.8589080572128296, "word": "Farc"},
],
[
{"entity_group": "I-PER", "score": 0.9962901175022125, "word": "Enzo"},
{"entity_group": "I-ORG", "score": 0.9986497163772583, "word": "UN"},
{"entity_group": "PER", "score": 0.9962901175022125, "word": "Enzo"},
{"entity_group": "ORG", "score": 0.9986497163772583, "word": "UN"},
],
]
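Following up on the review comment above, a sketch of what a hyphenated-name case might look like. The tokens, indices, and scores are invented for illustration; they are not real model output and not part of this PR.

# Hypothetical ungrouped input for a hyphenated name (illustrative only).
ungrouped_hyphenated_input = [
    {"entity": "B-PER", "index": 1, "score": 0.998, "word": "Juanita"},
    {"entity": "I-PER", "index": 2, "score": 0.997, "word": "Gomez"},
    {"entity": "I-PER", "index": 3, "score": 0.996, "word": "-"},
    {"entity": "I-PER", "index": 4, "score": 0.995, "word": "Cortez"},
]

# With grouped_entities=True all four tokens should collapse into one PER group.
# How the hyphen is re-joined depends on the tokenizer's convert_tokens_to_string,
# which is exactly where the fast and slow code paths diverge; the slow-style
# fallback " ".join(...).replace(" ##", "") would yield "Juanita Gomez - Cortez".
# e.g. self.assertEqual(nlp.group_entities(ungrouped_hyphenated_input)[0]["entity_group"], "PER")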

@@ -771,8 +785,15 @@ def _test_ner_pipeline(
for key in output_keys:
self.assertIn(key, result)

-        for ungrouped_input, grouped_result in zip(ungrouped_ner_inputs, expected_grouped_ner_results):
-            self.assertEqual(nlp.group_entities(ungrouped_input), grouped_result)
+        if nlp.grouped_entities:

Author comment: Conditioned the assertions on nlp.grouped_entities so that the grouped_entities=False tests won't fail because of the grouped_entities=True expectations.

Reviewer reply: LGTM

+            if nlp.ignore_subwords:

Author comment: Added cases for ignore_subwords=True and ignore_subwords=False.

+                for ungrouped_input, grouped_result in zip(ungrouped_ner_inputs, expected_grouped_ner_results):
+                    self.assertEqual(nlp.group_entities(ungrouped_input), grouped_result)
+            else:
+                for ungrouped_input, grouped_result in zip(
+                    ungrouped_ner_inputs, expected_grouped_ner_results_w_subword
+                ):
+                    self.assertEqual(nlp.group_entities(ungrouped_input), grouped_result)

@require_torch
def test_torch_ner(self):
@@ -785,7 +806,14 @@ def test_torch_ner(self):
def test_ner_grouped(self):
mandatory_keys = {"entity_group", "word", "score"}
for model_name in NER_FINETUNED_MODELS:
nlp = pipeline(task="ner", model=model_name, tokenizer=model_name, grouped_entities=True)
nlp = pipeline(
task="ner", model=model_name, tokenizer=model_name, grouped_entities=True, ignore_subwords=True
)
self._test_ner_pipeline(nlp, mandatory_keys)
+        for model_name in NER_FINETUNED_MODELS:
+            nlp = pipeline(
+                task="ner", model=model_name, tokenizer=model_name, grouped_entities=True, ignore_subwords=False
+            )
+            self._test_ner_pipeline(nlp, mandatory_keys)

@require_tf
@@ -799,7 +827,24 @@ def test_tf_ner(self):
def test_tf_ner_grouped(self):
mandatory_keys = {"entity_group", "word", "score"}
for model_name in NER_FINETUNED_MODELS:
nlp = pipeline(task="ner", model=model_name, tokenizer=model_name, framework="tf", grouped_entities=True)
nlp = pipeline(
task="ner",
model=model_name,
tokenizer=model_name,
framework="tf",
grouped_entities=True,
ignore_subwords=True,
)
self._test_ner_pipeline(nlp, mandatory_keys)
+        for model_name in NER_FINETUNED_MODELS:
+            nlp = pipeline(
+                task="ner",
+                model=model_name,
+                tokenizer=model_name,
+                framework="tf",
+                grouped_entities=True,
+                ignore_subwords=False,
+            )
+            self._test_ner_pipeline(nlp, mandatory_keys)


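A closing note, not part of the diff: group_entities prefers a tokenizer-supplied is_subword_fn over the default "##" check (see the hasattr branch in pipelines.py above). Below is a minimal sketch of attaching a custom detector, reusing the nlp object from the earlier sketch and assuming a SentencePiece-style vocabulary where word-initial pieces carry a "▁" prefix; the attachment pattern is an illustration, not an API this PR defines.

# Hypothetical: custom subword detection for a SentencePiece-style tokenizer.
# Assumption: pieces that do NOT start with "▁" continue the previous word.
def sentencepiece_is_subword(token: str) -> bool:
    return not token.startswith("▁")

# group_entities() picks this up via hasattr(self.tokenizer, "is_subword_fn").
nlp.tokenizer.is_subword_fn = sentencepiece_is_subword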