Merged (changes from 15 commits)
86 changes: 75 additions & 11 deletions src/transformers/pipelines.py
@@ -1299,6 +1299,29 @@ def __call__(self, *args, targets=None, **kwargs):
return results


class TokenClassificationArgumentHandler(ArgumentHandler):
Contributor Author: Added this to check offset_mapping if provided (does a simple batch_size check).

"""
Handles arguments for token classification.
"""

def __call__(self, *args, **kwargs):

if args is not None and len(args) > 0:
if len(args) == 1 and isinstance(args[0], (list, tuple)):
inputs = args[0]  # a batch passed as a single list or tuple
else:
inputs = list(args)  # one or more sentences passed positionally
batch_size = len(inputs)

offset_mapping = kwargs.get("offset_mapping", None)
if offset_mapping:
if isinstance(offset_mapping, list) and isinstance(offset_mapping[0], tuple):
offset_mapping = [offset_mapping]
if len(offset_mapping) != batch_size:
raise ValueError("offset_mapping should have the same batch size as the input")
return inputs, offset_mapping
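For context, a minimal usage sketch of this argument path (model name, sentence, and offsets are illustrative assumptions, not from this PR): with a slow tokenizer, the caller passes one list of (start, end) character spans per input sentence so that [UNK] tokens can later be mapped back to the original text.

from transformers import pipeline

# Illustrative only: checkpoint and offsets below are assumptions.
ner = pipeline("ner", model="dbmdz/bert-large-cased-finetuned-conll03-english")
sentence = "Enzo works at the UN"
offsets = [(0, 2), (2, 4), (5, 10), (11, 13), (14, 17), (18, 20)]  # one span per token
results = ner(sentence, offset_mapping=[offsets])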


@add_end_docstrings(
PIPELINE_INIT_ARGS,
r"""
@@ -1336,13 +1359,14 @@ def __init__(
ignore_labels=["O"],
task: str = "",
grouped_entities: bool = False,
ignore_subwords: bool = True,
):
super().__init__(
model=model,
tokenizer=tokenizer,
modelcard=modelcard,
framework=framework,
args_parser=args_parser,
args_parser=TokenClassificationArgumentHandler(),
device=device,
binary_output=binary_output,
task=task,
@@ -1357,6 +1381,7 @@ def __init__(
self._basic_tokenizer = BasicTokenizer(do_lower_case=False)
self.ignore_labels = ignore_labels
self.grouped_entities = grouped_entities
self.ignore_subwords = ignore_subwords
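Mirroring the tests further down, a minimal construction sketch for the new flag (checkpoint assumed):

from transformers import pipeline

# ignore_subwords only takes effect together with grouped_entities (see group_entities below).
ner = pipeline(
    "ner",
    model="dbmdz/bert-large-cased-finetuned-conll03-english",  # assumed checkpoint
    grouped_entities=True,
    ignore_subwords=True,
)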

def __call__(self, *args, **kwargs):
"""
@@ -1377,9 +1402,11 @@ def __call__(self, *args, **kwargs):
- **index** (:obj:`int`, only present when ``self.grouped_entities=False``) -- The index of the
corresponding token in the sentence.
"""
inputs = self._args_parser(*args, **kwargs)

inputs, offset_mappings = self._args_parser(*args, **kwargs)
answers = []
for sentence in inputs:

for i, sentence in enumerate(inputs):

# Manage correct placement of the tensors
with self.device_placement():
@@ -1389,7 +1416,18 @@
return_attention_mask=False,
return_tensors=self.framework,
truncation=True,
return_special_tokens_mask=True,
return_offsets_mapping=self.tokenizer.is_fast,
)
if self.tokenizer.is_fast:
offset_mapping = tokens["offset_mapping"].cpu().numpy()[0]
del tokens["offset_mapping"]
elif offset_mappings:
offset_mapping = offset_mappings[i]
else:
raise ValueError("To decode [UNK] tokens, use a fast tokenizer or provide the offset_mapping parameter")
special_tokens_mask = tokens["special_tokens_mask"].cpu().numpy()[0]
del tokens["special_tokens_mask"]

# Forward
if self.framework == "tf":
@@ -1406,21 +1444,33 @@

entities = []
# Filter to labels not in `self.ignore_labels`
# Filter special_tokens
filtered_labels_idx = [
(idx, label_idx)
for idx, label_idx in enumerate(labels_idx)
if self.model.config.id2label[label_idx] not in self.ignore_labels
if (self.model.config.id2label[label_idx] not in self.ignore_labels) and not special_tokens_mask[idx]
]

for idx, label_idx in filtered_labels_idx:
start_ind, end_ind = offset_mapping[idx]
word_ref = sentence[start_ind:end_ind]
word = self.tokenizer.convert_ids_to_tokens([int(input_ids[idx])])[0]
is_subword = len(word_ref) != len(word)
Contributor Author: Changed the is_subword detection logic: it now compares the length of the token ("##token") with the length of the original text span from the offset mapping (assuming subword pieces are prefixed with something). In case users want different logic, they can first get the ungrouped entities, add an is_subword: bool field to each entity, and call pipeline.group_entities themselves.
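A minimal sketch of that heuristic (tokens and offsets are hypothetical):

sentence = "Consuelo"
offset_mapping = [(0, 4), (4, 8)]             # character spans from the tokenizer
tokens = ["Cons", "##uelo"]                   # WordPiece output; "##" marks a continuation
for (start, end), token in zip(offset_mapping, tokens):
    word_ref = sentence[start:end]            # "Cons", then "uelo"
    is_subword = len(word_ref) != len(token)  # False, then True ("##uelo" vs "uelo")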


if int(input_ids[idx]) == self.tokenizer.unk_token_id:
word = word_ref
is_subword = False

entity = {
"word": self.tokenizer.convert_ids_to_tokens(int(input_ids[idx])),
"word": word,
"score": score[idx][label_idx].item(),
"entity": self.model.config.id2label[label_idx],
"index": idx,
}

if self.grouped_entities and self.ignore_subwords:
entity["is_subword"] = is_subword

entities += [entity]

# Append grouped entities
@@ -1442,14 +1492,17 @@ def group_sub_entities(self, entities: List[dict]) -> dict:
entities (:obj:`dict`): The entities predicted by the pipeline.
"""
# Get the first entity in the entity group
entity = entities[0]["entity"]
scores = np.mean([entity["score"] for entity in entities])
entity = entities[0]["entity"].split("-")[-1]
scores = np.nanmean([entity["score"] for entity in entities])
tokens = [entity["word"] for entity in entities]

if self.tokenizer.is_fast:
word = self.tokenizer.decode(self.tokenizer.convert_tokens_to_ids(tokens))
else:
word = self.tokenizer.convert_tokens_to_string(tokens)
Contributor Author (@cceyda, Oct 26, 2020): Fixed as suggested! I agree it is much cleaner this way. Umm, it looks like fast tokenizers have a convert_tokens_to_string method now? 😕

entity_group = {
"entity_group": entity,
"score": np.mean(scores),
"word": self.tokenizer.convert_tokens_to_string(tokens),
"word": word,
}
return entity_group
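A quick sketch of why np.nanmean is used (scores hypothetical): group_entities sets the score of ignored subwords to np.nan, so the mean is taken over word-initial tokens only.

import numpy as np

scores = [0.9994, np.nan, np.nan]  # "Cons" kept; subword scores were set to nan
print(np.nanmean(scores))          # 0.9994 -- np.mean would have returned nan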

@@ -1468,7 +1521,9 @@ def group_entities(self, entities: List[dict]) -> List[dict]:
last_idx = entities[-1]["index"]

for entity in entities:

is_last_idx = entity["index"] == last_idx
is_subword = self.ignore_subwords and entity["is_subword"]
if not entity_group_disagg:
entity_group_disagg += [entity]
if is_last_idx:
@@ -1477,10 +1532,19 @@

# If the current entity is similar and adjacent to the previous entity, append it to the disaggregated entity group
# The split is meant to account for the "B" and "I" prefixes
# Shouldn't merge if both entities are B-type
if (
entity["entity"].split("-")[-1] == entity_group_disagg[-1]["entity"].split("-")[-1]
(
entity["entity"].split("-")[-1] == entity_group_disagg[-1]["entity"].split("-")[-1]
and entity["entity"].split("-")[0] != "B"
)
and entity["index"] == entity_group_disagg[-1]["index"] + 1
):
) or is_subword:
# Modify subword type to be previous_type
if is_subword:
entity["entity"] = entity_group_disagg[-1]["entity"].split("-")[-1]
entity["score"] = np.nan # set ignored scores to nan and use np.nanmean

entity_group_disagg += [entity]
# Group the entities at the last entity
if is_last_idx:
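A worked sketch of the merge condition above (hypothetical entities):

prev = {"entity": "B-LOC", "index": 1}
cur = {"entity": "I-LOC", "index": 2, "is_subword": False}

same_type = cur["entity"].split("-")[-1] == prev["entity"].split("-")[-1]
not_b_tag = cur["entity"].split("-")[0] != "B"
adjacent = cur["index"] == prev["index"] + 1
merge = (same_type and not_b_tag and adjacent) or cur["is_subword"]
# merge -> True: "I-LOC" continues the "B-LOC" group.
# With cur["entity"] = "B-LOC", merge would be False: two B- tags start two groups.
# A subword merges unconditionally; its label is overwritten and its score set to nan.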
95 changes: 70 additions & 25 deletions tests/test_pipelines.py
@@ -718,35 +718,49 @@ def _test_ner_pipeline(

ungrouped_ner_inputs = [
[
{"entity": "B-PER", "index": 1, "score": 0.9994944930076599, "word": "Cons"},
{"entity": "B-PER", "index": 2, "score": 0.8025449514389038, "word": "##uelo"},
{"entity": "I-PER", "index": 3, "score": 0.9993102550506592, "word": "Ara"},
{"entity": "I-PER", "index": 4, "score": 0.9993743896484375, "word": "##új"},
{"entity": "I-PER", "index": 5, "score": 0.9992871880531311, "word": "##o"},
{"entity": "I-PER", "index": 6, "score": 0.9993029236793518, "word": "No"},
{"entity": "I-PER", "index": 7, "score": 0.9981776475906372, "word": "##guera"},
{"entity": "B-PER", "index": 15, "score": 0.9998136162757874, "word": "Andrés"},
{"entity": "I-PER", "index": 16, "score": 0.999740719795227, "word": "Pas"},
{"entity": "I-PER", "index": 17, "score": 0.9997414350509644, "word": "##tran"},
{"entity": "I-PER", "index": 18, "score": 0.9996136426925659, "word": "##a"},
{"entity": "B-ORG", "index": 28, "score": 0.9989739060401917, "word": "Far"},
{"entity": "I-ORG", "index": 29, "score": 0.7188422083854675, "word": "##c"},
{"entity": "B-PER", "index": 1, "score": 0.9994944930076599, "is_subword": False, "word": "Cons"},
{"entity": "B-PER", "index": 2, "score": 0.8025449514389038, "is_subword": True, "word": "##uelo"},
{"entity": "I-PER", "index": 3, "score": 0.9993102550506592, "is_subword": False, "word": "Ara"},
{"entity": "I-PER", "index": 4, "score": 0.9993743896484375, "is_subword": True, "word": "##új"},
{"entity": "I-PER", "index": 5, "score": 0.9992871880531311, "is_subword": True, "word": "##o"},
{"entity": "I-PER", "index": 6, "score": 0.9993029236793518, "is_subword": False, "word": "No"},
{"entity": "I-PER", "index": 7, "score": 0.9981776475906372, "is_subword": True, "word": "##guera"},
{"entity": "B-PER", "index": 15, "score": 0.9998136162757874, "is_subword": False, "word": "Andrés"},
{"entity": "I-PER", "index": 16, "score": 0.999740719795227, "is_subword": False, "word": "Pas"},
{"entity": "I-PER", "index": 17, "score": 0.9997414350509644, "is_subword": True, "word": "##tran"},
{"entity": "I-PER", "index": 18, "score": 0.9996136426925659, "is_subword": True, "word": "##a"},
{"entity": "B-ORG", "index": 28, "score": 0.9989739060401917, "is_subword": False, "word": "Far"},
{"entity": "I-ORG", "index": 29, "score": 0.7188422083854675, "is_subword": True, "word": "##c"},
],
[
{"entity": "I-PER", "index": 1, "score": 0.9968166351318359, "word": "En"},
{"entity": "I-PER", "index": 2, "score": 0.9957635998725891, "word": "##zo"},
{"entity": "I-ORG", "index": 7, "score": 0.9986497163772583, "word": "UN"},
{"entity": "I-PER", "index": 1, "score": 0.9968166351318359, "is_subword": False, "word": "En"},
{"entity": "I-PER", "index": 2, "score": 0.9957635998725891, "is_subword": True, "word": "##zo"},
{"entity": "I-ORG", "index": 7, "score": 0.9986497163772583, "is_subword": False, "word": "UN"},
],
]

expected_grouped_ner_results = [
[
{"entity_group": "B-PER", "score": 0.9710702640669686, "word": "Consuelo Araújo Noguera"},
{"entity_group": "B-PER", "score": 0.9997273534536362, "word": "Andrés Pastrana"},
{"entity_group": "B-ORG", "score": 0.8589080572128296, "word": "Farc"},
{"entity_group": "PER", "score": 0.999369223912557, "word": "Consuelo Araújo Noguera"},
{"entity_group": "PER", "score": 0.9997771680355072, "word": "Andrés Pastrana"},
Reviewer: A test for hyphenated names (e.g. Juanita Gomez-Cortez) would be useful, especially given that the fast and slow tokenizers have different code paths for reconstructing the original text. I had to implement grouping of named entities myself recently and was tripped up by that corner case. (A sketch of such a case follows this list.)

{"entity_group": "ORG", "score": 0.9989739060401917, "word": "Farc"},
],
[
{"entity_group": "PER", "score": 0.9968166351318359, "word": "Enzo"},
{"entity_group": "ORG", "score": 0.9986497163772583, "word": "UN"},
],
]
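Along the lines of the suggestion above, a hypothetical extra case (names, labels, and scores are illustrative, not from the PR):

hyphenated_input = [
    {"entity": "B-PER", "index": 1, "score": 0.99, "is_subword": False, "word": "Juanita"},
    {"entity": "I-PER", "index": 2, "score": 0.98, "is_subword": False, "word": "Gomez"},
    {"entity": "I-PER", "index": 3, "score": 0.97, "is_subword": False, "word": "-"},
    {"entity": "I-PER", "index": 4, "score": 0.98, "is_subword": False, "word": "Cortez"},
]
# Expected: a single {"entity_group": "PER", "word": "Juanita Gomez-Cortez", ...}; slow
# tokenizers rebuild the word with convert_tokens_to_string, which may insert spaces
# around "-" ("Gomez - Cortez") -- exactly the corner case mentioned above.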

expected_grouped_ner_results_w_subword = [
[
{"entity_group": "PER", "score": 0.9994944930076599, "word": "Cons"},
{"entity_group": "PER", "score": 0.9663328925768534, "word": "##uelo Araújo Noguera"},
{"entity_group": "PER", "score": 0.9997273534536362, "word": "Andrés Pastrana"},
{"entity_group": "ORG", "score": 0.8589080572128296, "word": "Farc"},
],
[
{"entity_group": "I-PER", "score": 0.9962901175022125, "word": "Enzo"},
{"entity_group": "I-ORG", "score": 0.9986497163772583, "word": "UN"},
{"entity_group": "PER", "score": 0.9962901175022125, "word": "Enzo"},
{"entity_group": "ORG", "score": 0.9986497163772583, "word": "UN"},
],
]

@@ -773,8 +787,15 @@ def _test_ner_pipeline(
for key in output_keys:
self.assertIn(key, result)

for ungrouped_input, grouped_result in zip(ungrouped_ner_inputs, expected_grouped_ner_results):
self.assertEqual(nlp.group_entities(ungrouped_input), grouped_result)
if nlp.grouped_entities:
Contributor Author: Conditioned so that grouped_entities=False tests won't fail because of grouped_entities=True.

Contributor: LGTM

if nlp.ignore_subwords:
Contributor Author: Added cases for ignore_subwords=True and False.

for ungrouped_input, grouped_result in zip(ungrouped_ner_inputs, expected_grouped_ner_results):
self.assertEqual(nlp.group_entities(ungrouped_input), grouped_result)
else:
for ungrouped_input, grouped_result in zip(
ungrouped_ner_inputs, expected_grouped_ner_results_w_subword
):
self.assertEqual(nlp.group_entities(ungrouped_input), grouped_result)

@require_torch
def test_torch_ner(self):
@@ -787,7 +808,14 @@ def test_torch_ner(self):
def test_ner_grouped(self):
mandatory_keys = {"entity_group", "word", "score"}
for model_name in NER_FINETUNED_MODELS:
nlp = pipeline(task="ner", model=model_name, tokenizer=model_name, grouped_entities=True)
nlp = pipeline(
task="ner", model=model_name, tokenizer=model_name, grouped_entities=True, ignore_subwords=True
)
self._test_ner_pipeline(nlp, mandatory_keys)
for model_name in NER_FINETUNED_MODELS:
nlp = pipeline(
task="ner", model=model_name, tokenizer=model_name, grouped_entities=True, ignore_subwords=False
)
self._test_ner_pipeline(nlp, mandatory_keys)

@require_tf
@@ -801,7 +829,24 @@ def test_tf_ner(self):
def test_tf_ner_grouped(self):
mandatory_keys = {"entity_group", "word", "score"}
for model_name in NER_FINETUNED_MODELS:
nlp = pipeline(task="ner", model=model_name, tokenizer=model_name, framework="tf", grouped_entities=True)
nlp = pipeline(
task="ner",
model=model_name,
tokenizer=model_name,
framework="tf",
grouped_entities=True,
ignore_subwords=True,
)
self._test_ner_pipeline(nlp, mandatory_keys)
for model_name in NER_FINETUNED_MODELS:
nlp = pipeline(
task="ner",
model=model_name,
tokenizer=model_name,
framework="tf",
grouped_entities=True,
ignore_subwords=False,
)
self._test_ner_pipeline(nlp, mandatory_keys)

