Skip to content
Merged
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 32 additions & 11 deletions src/transformers/pipelines.py
Original file line number Diff line number Diff line change
Expand Up @@ -998,6 +998,8 @@ def __init__(
ignore_labels=["O"],
task: str = "",
grouped_entities: bool = False,
skip_special_tokens: bool = False,
ignore_subwords: bool = False,
):
super().__init__(
model=model,
Expand All @@ -1019,6 +1021,8 @@ def __init__(
self._basic_tokenizer = BasicTokenizer(do_lower_case=False)
self.ignore_labels = ignore_labels
self.grouped_entities = grouped_entities
self.skip_special_tokens = skip_special_tokens
self.ignore_subwords = ignore_subwords

def __call__(self, *args, **kwargs):
inputs = self._args_parser(*args, **kwargs)
Expand Down Expand Up @@ -1054,15 +1058,18 @@ def __call__(self, *args, **kwargs):
]

for idx, label_idx in filtered_labels_idx:
word = self.tokenizer.convert_ids_to_tokens(
[int(input_ids[idx])], skip_special_tokens=self.skip_special_tokens
)
if word:
entity = {
"word": word[0],
"score": score[idx][label_idx].item(),
"entity": self.model.config.id2label[label_idx],
"index": idx,
}

entity = {
"word": self.tokenizer.convert_ids_to_tokens(int(input_ids[idx])),
"score": score[idx][label_idx].item(),
"entity": self.model.config.id2label[label_idx],
"index": idx,
}

entities += [entity]
entities += [entity]

# Append grouped entities
if self.grouped_entities:
Expand All @@ -1080,8 +1087,8 @@ def group_sub_entities(self, entities: List[dict]) -> dict:
Returns grouped sub entities
"""
# Get the first entity in the entity group
entity = entities[0]["entity"]
scores = np.mean([entity["score"] for entity in entities])
entity = entities[0]["entity"].split("-")[-1]
scores = np.nanmean([entity["score"] for entity in entities])
tokens = [entity["word"] for entity in entities]

entity_group = {
Expand All @@ -1096,14 +1103,21 @@ def group_entities(self, entities: List[dict]) -> List[dict]:
Returns grouped entities
"""

def is_subword(token: str) -> bool:
if token.startswith("##"):
return True
return False

entity_groups = []
entity_group_disagg = []

if entities:
last_idx = entities[-1]["index"]

for entity in entities:

is_last_idx = entity["index"] == last_idx
is_subword = self.ignore_subwords and is_subword(entity["word"])
if not entity_group_disagg:
entity_group_disagg += [entity]
if is_last_idx:
Expand All @@ -1112,10 +1126,17 @@ def group_entities(self, entities: List[dict]) -> List[dict]:

# If the current entity is similar and adjacent to the previous entity, append it to the disaggregated entity group
# The split is meant to account for the "B" and "I" suffixes
# Shouldn't merge if both entities are B-type
if (
entity["entity"].split("-")[-1] == entity_group_disagg[-1]["entity"].split("-")[-1]
and entity["entity"].split("-")[0] != "B"
and entity["index"] == entity_group_disagg[-1]["index"] + 1
):
) or is_subword:
# Modify subword type to be previous_type
if is_subword:
entity["entity"] = entity_group_disagg[-1]["entity"].split("-")[-1]
entity["score"] = np.nan # Handle ignored scores as 0/nan?
# How to handle index?
entity_group_disagg += [entity]
# Group the entities at the last entity
if is_last_idx:
Expand Down