Skip to content

Commit 2e23346

Browse files
svlandegrmitsch
andauthored
Fix use_gold_ents behaviour for EntityLinker (#13400)
* fix type annotation in docs * only restore entities after loss calculation * restore entities of sample in initialization * rename overfitting function * fix EL scorer * Relax test * fix formatting * Update spacy/pipeline/entity_linker.py Co-authored-by: Raphael Mitsch <[email protected]> * rename to _ensure_ents * further rename * allow for scorer to be None --------- Co-authored-by: Raphael Mitsch <[email protected]>
1 parent 2e96797 commit 2e23346

File tree

3 files changed

+145
-27
lines changed

3 files changed

+145
-27
lines changed

spacy/pipeline/entity_linker.py

Lines changed: 39 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111
from ..errors import Errors
1212
from ..kb import Candidate, KnowledgeBase
1313
from ..language import Language
14-
from ..ml import empty_kb
1514
from ..scorer import Scorer
1615
from ..tokens import Doc, Span
1716
from ..training import Example, validate_examples, validate_get_examples
@@ -105,7 +104,7 @@ def make_entity_linker(
105104
): Function that produces a list of candidates, given a certain knowledge base and several textual mentions.
106105
generate_empty_kb (Callable[[Vocab, int], KnowledgeBase]): Callable returning empty KnowledgeBase.
107106
scorer (Optional[Callable]): The scoring method.
108-
use_gold_ents (bool): Whether to copy entities from gold docs or not. If false, another
107+
use_gold_ents (bool): Whether to copy entities from gold docs during training or not. If false, another
109108
component must provide entity annotations.
110109
candidates_batch_size (int): Size of batches for entity candidate generation.
111110
threshold (Optional[float]): Confidence threshold for entity predictions. If confidence is below the threshold,
@@ -235,14 +234,44 @@ def __init__(
235234
self.cfg: Dict[str, Any] = {"overwrite": overwrite}
236235
self.distance = CosineDistance(normalize=False)
237236
self.kb = generate_empty_kb(self.vocab, entity_vector_length)
238-
self.scorer = scorer
239237
self.use_gold_ents = use_gold_ents
240238
self.candidates_batch_size = candidates_batch_size
241239
self.threshold = threshold
242240

243241
if candidates_batch_size < 1:
244242
raise ValueError(Errors.E1044)
245243

244+
def _score_with_ents_set(examples: Iterable[Example], **kwargs):
245+
# Because of how spaCy works, we can't just score immediately, because Language.evaluate
246+
# calls pipe() on the predicted docs, which won't have entities if there is no NER in the pipeline.
247+
if not scorer:
248+
return scorer
249+
if not self.use_gold_ents:
250+
return scorer(examples, **kwargs)
251+
else:
252+
examples = self._ensure_ents(examples)
253+
docs = self.pipe(
254+
(eg.predicted for eg in examples),
255+
)
256+
for eg, doc in zip(examples, docs):
257+
eg.predicted = doc
258+
return scorer(examples, **kwargs)
259+
260+
self.scorer = _score_with_ents_set
261+
262+
def _ensure_ents(self, examples: Iterable[Example]) -> Iterable[Example]:
263+
"""If use_gold_ents is true, set the gold entities to (a copy of) eg.predicted."""
264+
if not self.use_gold_ents:
265+
return examples
266+
267+
new_examples = []
268+
for eg in examples:
269+
ents, _ = eg.get_aligned_ents_and_ner()
270+
new_eg = eg.copy()
271+
new_eg.predicted.ents = ents
272+
new_examples.append(new_eg)
273+
return new_examples
274+
246275
def set_kb(self, kb_loader: Callable[[Vocab], KnowledgeBase]):
247276
"""Define the KB of this pipe by providing a function that will
248277
create it using this object's vocab."""
@@ -284,11 +313,9 @@ def initialize(
284313
nO = self.kb.entity_vector_length
285314
doc_sample = []
286315
vector_sample = []
287-
for eg in islice(get_examples(), 10):
316+
examples = self._ensure_ents(islice(get_examples(), 10))
317+
for eg in examples:
288318
doc = eg.x
289-
if self.use_gold_ents:
290-
ents, _ = eg.get_aligned_ents_and_ner()
291-
doc.ents = ents
292319
doc_sample.append(doc)
293320
vector_sample.append(self.model.ops.alloc1f(nO))
294321
assert len(doc_sample) > 0, Errors.E923.format(name=self.name)
@@ -354,43 +381,31 @@ def update(
354381
losses.setdefault(self.name, 0.0)
355382
if not examples:
356383
return losses
384+
examples = self._ensure_ents(examples)
357385
validate_examples(examples, "EntityLinker.update")
358386

359-
set_dropout_rate(self.model, drop)
360-
docs = [eg.predicted for eg in examples]
361-
# save to restore later
362-
old_ents = [doc.ents for doc in docs]
363-
364-
for doc, ex in zip(docs, examples):
365-
if self.use_gold_ents:
366-
ents, _ = ex.get_aligned_ents_and_ner()
367-
doc.ents = ents
368-
else:
369-
# only keep matching ents
370-
doc.ents = ex.get_matching_ents()
371-
372387
# make sure we have something to learn from, if not, short-circuit
373388
if not self.batch_has_learnable_example(examples):
374389
return losses
375390

391+
set_dropout_rate(self.model, drop)
392+
docs = [eg.predicted for eg in examples]
376393
sentence_encodings, bp_context = self.model.begin_update(docs)
377394

378-
# now restore the ents
379-
for doc, old in zip(docs, old_ents):
380-
doc.ents = old
381-
382395
loss, d_scores = self.get_loss(
383396
sentence_encodings=sentence_encodings, examples=examples
384397
)
385398
bp_context(d_scores)
386399
if sgd is not None:
387400
self.finish_update(sgd)
388401
losses[self.name] += loss
402+
389403
return losses
390404

391405
def get_loss(self, examples: Iterable[Example], sentence_encodings: Floats2d):
392406
validate_examples(examples, "EntityLinker.get_loss")
393407
entity_encodings = []
408+
# We assume that get_loss is called with gold ents set in the examples if need be
394409
eidx = 0 # indices in gold entities to keep
395410
keep_ents = [] # indices in sentence_encodings to keep
396411

spacy/tests/pipeline/test_entity_linker.py

Lines changed: 105 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -717,7 +717,7 @@ def test_preserving_links_ents_2(nlp):
717717
# fmt: on
718718

719719

720-
def test_overfitting_IO():
720+
def test_overfitting_IO_gold_entities():
721721
# Simple test to try and quickly overfit the NEL component - ensuring the ML models work correctly
722722
nlp = English()
723723
vector_length = 3
@@ -744,7 +744,9 @@ def create_kb(vocab):
744744
return mykb
745745

746746
# Create the Entity Linker component and add it to the pipeline
747-
entity_linker = nlp.add_pipe("entity_linker", last=True)
747+
entity_linker = nlp.add_pipe(
748+
"entity_linker", last=True, config={"use_gold_ents": True}
749+
)
748750
assert isinstance(entity_linker, EntityLinker)
749751
entity_linker.set_kb(create_kb)
750752
assert "Q2146908" in entity_linker.vocab.strings
@@ -807,6 +809,107 @@ def create_kb(vocab):
807809
assert_equal(batch_deps_1, batch_deps_2)
808810
assert_equal(batch_deps_1, no_batch_deps)
809811

812+
eval = nlp.evaluate(train_examples)
813+
assert "nel_macro_p" in eval
814+
assert "nel_macro_r" in eval
815+
assert "nel_macro_f" in eval
816+
assert "nel_micro_p" in eval
817+
assert "nel_micro_r" in eval
818+
assert "nel_micro_f" in eval
819+
assert "nel_f_per_type" in eval
820+
assert "PERSON" in eval["nel_f_per_type"]
821+
822+
assert eval["nel_macro_f"] > 0
823+
assert eval["nel_micro_f"] > 0
824+
825+
826+
def test_overfitting_IO_with_ner():
827+
# Simple test to try and overfit the NER and NEL component in combination - ensuring the ML models work correctly
828+
nlp = English()
829+
vector_length = 3
830+
assert "Q2146908" not in nlp.vocab.strings
831+
832+
# Convert the texts to docs to make sure we have doc.ents set for the training examples
833+
train_examples = []
834+
for text, annotation in TRAIN_DATA:
835+
doc = nlp(text)
836+
train_examples.append(Example.from_dict(doc, annotation))
837+
838+
def create_kb(vocab):
839+
# create artificial KB - assign same prior weight to the two russ cochran's
840+
# Q2146908 (Russ Cochran): American golfer
841+
# Q7381115 (Russ Cochran): publisher
842+
mykb = InMemoryLookupKB(vocab, entity_vector_length=vector_length)
843+
mykb.add_entity(entity="Q2146908", freq=12, entity_vector=[6, -4, 3])
844+
mykb.add_entity(entity="Q7381115", freq=12, entity_vector=[9, 1, -7])
845+
mykb.add_alias(
846+
alias="Russ Cochran",
847+
entities=["Q2146908", "Q7381115"],
848+
probabilities=[0.5, 0.5],
849+
)
850+
return mykb
851+
852+
# Create the NER and EL components and add them to the pipeline
853+
ner = nlp.add_pipe("ner", first=True)
854+
entity_linker = nlp.add_pipe(
855+
"entity_linker", last=True, config={"use_gold_ents": False}
856+
)
857+
entity_linker.set_kb(create_kb)
858+
859+
train_examples = []
860+
for text, annotations in TRAIN_DATA:
861+
train_examples.append(Example.from_dict(nlp.make_doc(text), annotations))
862+
for ent in annotations.get("entities"):
863+
ner.add_label(ent[2])
864+
optimizer = nlp.initialize()
865+
866+
# train the NER and NEL pipes
867+
for i in range(50):
868+
losses = {}
869+
nlp.update(train_examples, sgd=optimizer, losses=losses)
870+
assert losses["ner"] < 0.001
871+
assert losses["entity_linker"] < 0.001
872+
873+
# adding additional components that are required for the entity_linker
874+
nlp.add_pipe("sentencizer", first=True)
875+
876+
# test the trained model
877+
test_text = "Russ Cochran captured his first major title with his son as caddie."
878+
doc = nlp(test_text)
879+
ents = doc.ents
880+
assert len(ents) == 1
881+
assert ents[0].text == "Russ Cochran"
882+
assert ents[0].label_ == "PERSON"
883+
assert ents[0].kb_id_ != "NIL"
884+
885+
# TODO: below assert is still flaky - EL doesn't properly overfit quite yet
886+
# assert ents[0].kb_id_ == "Q2146908"
887+
888+
# Also test the results are still the same after IO
889+
with make_tempdir() as tmp_dir:
890+
nlp.to_disk(tmp_dir)
891+
nlp2 = util.load_model_from_path(tmp_dir)
892+
assert nlp2.pipe_names == nlp.pipe_names
893+
doc2 = nlp2(test_text)
894+
ents2 = doc2.ents
895+
assert len(ents2) == 1
896+
assert ents2[0].text == "Russ Cochran"
897+
assert ents2[0].label_ == "PERSON"
898+
assert ents2[0].kb_id_ != "NIL"
899+
900+
eval = nlp.evaluate(train_examples)
901+
assert "nel_macro_f" in eval
902+
assert "nel_micro_f" in eval
903+
assert "ents_f" in eval
904+
assert "nel_f_per_type" in eval
905+
assert "ents_per_type" in eval
906+
assert "PERSON" in eval["nel_f_per_type"]
907+
assert "PERSON" in eval["ents_per_type"]
908+
909+
assert eval["nel_macro_f"] > 0
910+
assert eval["nel_micro_f"] > 0
911+
assert eval["ents_f"] > 0
912+
810913

811914
def test_kb_serialization():
812915
# Test that the KB can be used in a pipeline with a different vocab

website/docs/api/entitylinker.mdx

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ architectures and their arguments and hyperparameters.
6161
| `incl_context` | Whether or not to include the local context in the model. Defaults to `True`. ~~bool~~ |
6262
| `model` | The [`Model`](https://thinc.ai/docs/api-model) powering the pipeline component. Defaults to [EntityLinker](/api/architectures#EntityLinker). ~~Model~~ |
6363
| `entity_vector_length` | Size of encoding vectors in the KB. Defaults to `64`. ~~int~~ |
64-
| `use_gold_ents` | Whether to copy entities from the gold docs or not. Defaults to `True`. If `False`, entities must be set in the training data or by an annotating component in the pipeline. ~~int~~ |
64+
| `use_gold_ents` | Whether to copy entities from the gold docs or not. Defaults to `True`. If `False`, entities must be set in the training data or by an annotating component in the pipeline. ~~bool~~ |
6565
| `get_candidates` | Function that generates plausible candidates for a given `Span` object. Defaults to [CandidateGenerator](/api/architectures#CandidateGenerator), a function looking up exact, case-dependent aliases in the KB. ~~Callable[[KnowledgeBase, Span], Iterable[Candidate]]~~ |
6666
| `get_candidates_batch` <Tag variant="new">3.5</Tag> | Function that generates plausible candidates for a given batch of `Span` objects. Defaults to [CandidateBatchGenerator](/api/architectures#CandidateBatchGenerator), a function looking up exact, case-dependent aliases in the KB. ~~Callable[[KnowledgeBase, Iterable[Span]], Iterable[Iterable[Candidate]]]~~ |
6767
| `generate_empty_kb` <Tag variant="new">3.5.1</Tag> | Function that generates an empty `KnowledgeBase` object. Defaults to [`spacy.EmptyKB.v2`](/api/architectures#EmptyKB), which generates an empty [`InMemoryLookupKB`](/api/inmemorylookupkb). ~~Callable[[Vocab, int], KnowledgeBase]~~ |

0 commit comments

Comments
 (0)