-
Notifications
You must be signed in to change notification settings
Fork 31.9k
[WIP] Ner pipeline grouped_entities fixes #5970
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 13 commits
85d7554
31176c0
590ed80
56860f7
22d21cb
47a5e21
77f93e1
87c327e
456451a
99f7aad
188fc0b
b8d4b99
bd1c9bb
ba6dacb
9221ca6
2585ea2
47797d1
92115ee
0cf0e73
8e77d26
4b3d8eb
3bc55e4
70a4dc5
File filter
Filter by extension
Conversations
Jump to
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -736,15 +736,29 @@ def _test_ner_pipeline( | |
| {"entity": "I-ORG", "index": 7, "score": 0.9986497163772583, "word": "UN"}, | ||
| ], | ||
| ] | ||
|
|
||
| expected_grouped_ner_results = [ | ||
| [ | ||
| {"entity_group": "B-PER", "score": 0.9710702640669686, "word": "Consuelo Araújo Noguera"}, | ||
| {"entity_group": "B-PER", "score": 0.9997273534536362, "word": "Andrés Pastrana"}, | ||
| {"entity_group": "B-ORG", "score": 0.8589080572128296, "word": "Farc"}, | ||
| {"entity_group": "PER", "score": 0.999369223912557, "word": "Consuelo Araújo Noguera"}, | ||
| {"entity_group": "PER", "score": 0.9997771680355072, "word": "Andrés Pastrana"}, | ||
|
||
| {"entity_group": "ORG", "score": 0.9989739060401917, "word": "Farc"}, | ||
| ], | ||
| [ | ||
| {"entity_group": "PER", "score": 0.9968166351318359, "word": "Enzo"}, | ||
| {"entity_group": "ORG", "score": 0.9986497163772583, "word": "UN"}, | ||
| ], | ||
| ] | ||
|
|
||
| expected_grouped_ner_results_w_subword = [ | ||
| [ | ||
| {"entity_group": "PER", "score": 0.9994944930076599, "word": "Cons"}, | ||
| {"entity_group": "PER", "score": 0.9663328925768534, "word": "##uelo Araújo Noguera"}, | ||
| {"entity_group": "PER", "score": 0.9997273534536362, "word": "Andrés Pastrana"}, | ||
| {"entity_group": "ORG", "score": 0.8589080572128296, "word": "Farc"}, | ||
| ], | ||
| [ | ||
| {"entity_group": "I-PER", "score": 0.9962901175022125, "word": "Enzo"}, | ||
| {"entity_group": "I-ORG", "score": 0.9986497163772583, "word": "UN"}, | ||
| {"entity_group": "PER", "score": 0.9962901175022125, "word": "Enzo"}, | ||
| {"entity_group": "ORG", "score": 0.9986497163772583, "word": "UN"}, | ||
| ], | ||
| ] | ||
|
|
||
|
|
@@ -771,8 +785,15 @@ def _test_ner_pipeline( | |
| for key in output_keys: | ||
| self.assertIn(key, result) | ||
|
|
||
| for ungrouped_input, grouped_result in zip(ungrouped_ner_inputs, expected_grouped_ner_results): | ||
| self.assertEqual(nlp.group_entities(ungrouped_input), grouped_result) | ||
| if nlp.grouped_entities: | ||
|
||
| if nlp.ignore_subwords: | ||
|
||
| for ungrouped_input, grouped_result in zip(ungrouped_ner_inputs, expected_grouped_ner_results): | ||
| self.assertEqual(nlp.group_entities(ungrouped_input), grouped_result) | ||
| else: | ||
| for ungrouped_input, grouped_result in zip( | ||
| ungrouped_ner_inputs, expected_grouped_ner_results_w_subword | ||
| ): | ||
| self.assertEqual(nlp.group_entities(ungrouped_input), grouped_result) | ||
|
|
||
| @require_torch | ||
| def test_torch_ner(self): | ||
|
|
@@ -785,7 +806,14 @@ def test_torch_ner(self): | |
| def test_ner_grouped(self): | ||
| mandatory_keys = {"entity_group", "word", "score"} | ||
| for model_name in NER_FINETUNED_MODELS: | ||
| nlp = pipeline(task="ner", model=model_name, tokenizer=model_name, grouped_entities=True) | ||
| nlp = pipeline( | ||
| task="ner", model=model_name, tokenizer=model_name, grouped_entities=True, ignore_subwords=True | ||
| ) | ||
| self._test_ner_pipeline(nlp, mandatory_keys) | ||
| for model_name in NER_FINETUNED_MODELS: | ||
| nlp = pipeline( | ||
| task="ner", model=model_name, tokenizer=model_name, grouped_entities=True, ignore_subwords=False | ||
| ) | ||
| self._test_ner_pipeline(nlp, mandatory_keys) | ||
|
|
||
| @require_tf | ||
|
|
@@ -799,7 +827,24 @@ def test_tf_ner(self): | |
| def test_tf_ner_grouped(self): | ||
| mandatory_keys = {"entity_group", "word", "score"} | ||
| for model_name in NER_FINETUNED_MODELS: | ||
| nlp = pipeline(task="ner", model=model_name, tokenizer=model_name, framework="tf", grouped_entities=True) | ||
| nlp = pipeline( | ||
| task="ner", | ||
| model=model_name, | ||
| tokenizer=model_name, | ||
| framework="tf", | ||
| grouped_entities=True, | ||
| ignore_subwords=True, | ||
| ) | ||
| self._test_ner_pipeline(nlp, mandatory_keys) | ||
| for model_name in NER_FINETUNED_MODELS: | ||
| nlp = pipeline( | ||
| task="ner", | ||
| model=model_name, | ||
| tokenizer=model_name, | ||
| framework="tf", | ||
| grouped_entities=True, | ||
| ignore_subwords=False, | ||
| ) | ||
| self._test_ner_pipeline(nlp, mandatory_keys) | ||
|
|
||
|
|
||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.