Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
120 changes: 113 additions & 7 deletions nlptest/modelhandler/jsl_modelhandler.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import os
from typing import List, Union
from typing import List, Union, Dict, Tuple
import numpy as np

from .modelhandler import _ModelHandler
from ..utils.custom_types import NEROutput, NERPrediction, SequenceClassificationOutput
Expand Down Expand Up @@ -117,7 +118,107 @@ def __init__(

# in order to overwrite configs, light pipeline should be reinitialized.
self.model = LightPipeline(model)

@staticmethod
def _aggregate_words(prediction: List[Dict]) -> List[Dict]:
    """
    Converts token-level annotations into a list of plain dictionaries.

    Every annotation produces exactly one dictionary; no tokens are merged
    at this stage.

    Args:
        prediction (List[Dict]):
            annotations obtained with the pipeline object. Each item is read
            through its `result`, `metadata`, `begin` and `end` attributes,
            and `metadata` must provide the 'confidence' and 'word' keys.
    Returns:
        List[Dict]:
            one dictionary per annotation with the keys 'entity', 'score',
            'index' (1-based position in `prediction`), 'word', 'start'
            and 'end'.
    """
    return [
        {
            'entity': token.result,
            # the confidence is converted with float() before being stored
            'score': float(token.metadata['confidence']),
            'index': position,
            'word': token.metadata['word'],
            'start': token.begin,
            'end': token.end,
        }
        for position, token in enumerate(prediction, start=1)
    ]

@staticmethod
def _get_tag(entity_label: str) -> Tuple[str, str]:
    """
    Splits a BIO style label into its prefix and its entity type.

    Args:
        entity_label (str):
            BIO style label
    Returns:
        Tuple[str,str]:
            tag, label. Labels that do not start with "B-" or "I-" are
            returned as ("I", "O").
    """
    if entity_label.startswith(("B-", "I-")):
        # Split on the first hyphen only: an entity type may itself contain
        # hyphens, and callers unpack exactly two values.
        prefix, _, label = entity_label.partition("-")
        return prefix, label
    return "I", "O"

@staticmethod
def _group_sub_entities(entities: List[dict]) -> dict:
    """
    Group together the adjacent tokens with the same entity predicted.

    Args:
        entities (List[dict]):
            The non-empty run of token predictions to merge. Each item must
            provide the 'entity', 'score', 'word', 'start' and 'end' keys.
    Returns:
        dict:
            a single entry with the keys 'entity_group' (label of the first
            token with everything up to its first hyphen removed), 'score'
            (mean of the token scores, ignoring NaN), 'word' (token words
            joined by a single space), 'start' (start of the first token)
            and 'end' (end of the last token).
    """
    # The group label comes from the first token. Only the part before the
    # first hyphen is dropped, so hyphenated entity types stay intact.
    group_label = entities[0]["entity"].split("-", 1)[-1]
    mean_score = np.nanmean([token["score"] for token in entities])
    words = [token["word"] for token in entities]

    return {
        "entity_group": group_label,
        "score": mean_score,
        "word": " ".join(words),
        "start": entities[0]["start"],
        "end": entities[-1]["end"],
    }

def group_entities(self, entities: List[Dict]) -> List[Dict]:
    """
    Find and group together the adjacent tokens with the same entity predicted.
    Inspired and adapted from:
    https://github.com/huggingface/transformers/blob/68287689f2f0d8b7063c400230b3766987abf18d/src/transformers/pipelines/token_classification.py#L421

    A token extends the current group only when its label (as returned by
    `self._get_tag`) equals the label of the previous token, is not "O",
    and its prefix is not "B". Any other token closes the current group and
    opens a new one, so every "O" token ends up in a group of its own.

    Args:
        entities (List[Dict]):
            The entities predicted by the pipeline.
    Returns:
        List[Dict]:
            grouped entities, one `self._group_sub_entities` result per
            group, in input order.
    """
    entity_groups = []
    current_group = []

    for entity in entities:
        if current_group:
            bi, tag = self._get_tag(entity["entity"])
            _, last_tag = self._get_tag(current_group[-1]["entity"])

            if tag != "O" and tag == last_tag and bi != "B":
                # Continuation of the entity that is currently being built.
                current_group.append(entity)
                continue

            # Anything else closes the group that was being built.
            entity_groups.append(self._group_sub_entities(current_group))

        # Either the very first token or the start of a new group.
        current_group = [entity]

    # Flush the trailing group; nothing to flush when `entities` is empty.
    if current_group:
        entity_groups.append(self._group_sub_entities(current_group))

    return entity_groups

@classmethod
def load_model(cls, path: str) -> 'NLUPipeline':
"""
Expand Down Expand Up @@ -147,17 +248,22 @@ def predict(self, text: str, *args, **kwargs) -> NEROutput:
NEROutput: A list of named entities recognized in the input text.
"""
prediction = self.model.fullAnnotate(text)[0][self.output_col]
aggregated_words = self._aggregate_words(prediction)
aggregated_predictions = self.group_entities(aggregated_words)

return NEROutput(
predictions=[
NERPrediction.from_span(
entity=ent.result,
word=ent.metadata['word'],
start=ent.begin,
end=ent.end,
score=ent.metadata[ent.result]
) for ent in prediction
entity=ent['entity_group'],
word=ent['word'],
start=ent['start'],
end=ent['end'],
score=ent['score']
) for ent in aggregated_predictions
]
)



def predict_raw(self, text: str) -> List[str]:
"""Perform predictions with SparkNLP LightPipeline on the input text.
Expand Down