-
Notifications
You must be signed in to change notification settings - Fork 700
find related docs in single query #35727
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,142 @@ | ||
| // Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. | ||
|
|
||
| package com.yahoo.search.searchers; | ||
|
|
||
| import com.yahoo.api.annotations.Beta; | ||
| import com.yahoo.prelude.query.AndItem; | ||
| import com.yahoo.prelude.query.Item; | ||
| import com.yahoo.prelude.query.NearestNeighborItem; | ||
| import com.yahoo.prelude.query.NotItem; | ||
| import com.yahoo.prelude.query.NullItem; | ||
| import com.yahoo.prelude.query.WordItem; | ||
| import com.yahoo.processing.request.CompoundName; | ||
| import com.yahoo.search.Query; | ||
| import com.yahoo.search.Result; | ||
| import com.yahoo.search.Searcher; | ||
| import com.yahoo.search.result.ErrorMessage; | ||
| import com.yahoo.search.result.Hit; | ||
| import com.yahoo.search.searchchain.Execution; | ||
| import com.yahoo.tensor.Tensor; | ||
|
|
||
| /** | ||
| * Finds documents related to a given document using nearest neighbor search on embeddings. | ||
| * | ||
| * <p>This searcher takes a document ID, fetches the embedding from that document, | ||
| * and performs a nearest neighbor search to find similar documents.</p> | ||
| * | ||
| * <h2>Query parameters:</h2> | ||
| * <ul> | ||
| * <li><b>relatedTo.id</b> - The ID of the source document to find related documents for (required)</li> | ||
| * <li><b>relatedTo.idField</b> - The field containing the document ID (default: "id")</li> | ||
| * <li><b>relatedTo.embeddingField</b> - The embedding field to use for NN search (required)</li> | ||
| * <li><b>relatedTo.queryTensorName</b> - The name of the query tensor to use, must match rank profile (required)</li> | ||
| * <li><b>relatedTo.summary</b> - The summary class containing the embedding (default: same as embeddingField)</li> | ||
| * <li><b>relatedTo.exploreAdditionalHits</b> - Additional candidates to explore beyond hits+offset (default: 100)</li> | ||
| * <li><b>relatedTo.excludeSource</b> - Whether to exclude the source document from results (default: true)</li> | ||
| * </ul> | ||
| * | ||
| * @author andreer | ||
| */ | ||
| @Beta | ||
| public class RelatedDocumentsByNearestNeighborSearcher extends Searcher { | ||
|
|
||
| private static final CompoundName RELATED_TO_ID = CompoundName.from("relatedTo.id"); | ||
| private static final CompoundName RELATED_TO_ID_FIELD = CompoundName.from("relatedTo.idField"); | ||
| private static final CompoundName RELATED_TO_EMBEDDING_FIELD = CompoundName.from("relatedTo.embeddingField"); | ||
| private static final CompoundName RELATED_TO_QUERY_TENSOR_NAME = CompoundName.from("relatedTo.queryTensorName"); | ||
| private static final CompoundName RELATED_TO_SUMMARY = CompoundName.from("relatedTo.summary"); | ||
| private static final CompoundName RELATED_TO_EXPLORE_ADDITIONAL_HITS = CompoundName.from("relatedTo.exploreAdditionalHits"); | ||
| private static final CompoundName RELATED_TO_EXCLUDE_SOURCE = CompoundName.from("relatedTo.excludeSource"); | ||
|
|
||
| @Override | ||
| public Result search(Query query, Execution execution) { | ||
| String sourceId = query.properties().getString(RELATED_TO_ID); | ||
| if (sourceId == null) { | ||
| return execution.search(query); | ||
| } | ||
|
|
||
| String embeddingField = query.properties().getString(RELATED_TO_EMBEDDING_FIELD); | ||
| if (embeddingField == null) { | ||
| return new Result(query, ErrorMessage.createIllegalQuery( | ||
| "relatedTo.embeddingField is required when using relatedTo.id")); | ||
| } | ||
|
|
||
| String queryTensorName = query.properties().getString(RELATED_TO_QUERY_TENSOR_NAME); | ||
| if (queryTensorName == null) { | ||
| return new Result(query, ErrorMessage.createIllegalQuery( | ||
| "relatedTo.queryTensorName is required when using relatedTo.id")); | ||
| } | ||
|
|
||
| String idField = query.properties().getString(RELATED_TO_ID_FIELD, "id"); | ||
| String summary = query.properties().getString(RELATED_TO_SUMMARY, embeddingField); | ||
| int targetHits = query.getHits() + query.getOffset(); | ||
| int exploreAdditionalHits = query.properties().getInteger(RELATED_TO_EXPLORE_ADDITIONAL_HITS, 100); | ||
| boolean excludeSource = query.properties().getBoolean(RELATED_TO_EXCLUDE_SOURCE, true); | ||
|
|
||
| Tensor embedding = fetchEmbedding(sourceId, idField, embeddingField, summary, execution, query); | ||
| if (embedding == null) { | ||
| return new Result(query, ErrorMessage.createBackendCommunicationError( | ||
| "Could not find document with " + idField + "=" + sourceId + " or it has no " + embeddingField)); | ||
| } | ||
|
|
||
| addNearestNeighborItem(embedding, embeddingField, queryTensorName, targetHits, exploreAdditionalHits, query); | ||
|
|
||
| if (excludeSource) { | ||
| excludeSourceDocument(sourceId, idField, query); | ||
| } | ||
|
|
||
| return execution.search(query); | ||
| } | ||
|
|
||
| private Tensor fetchEmbedding(String sourceId, String idField, String embeddingField, String summary, | ||
| Execution execution, Query query) { | ||
| Query fetchQuery = new Query(); | ||
| query.attachContext(fetchQuery); | ||
| fetchQuery.getPresentation().setSummary(summary); | ||
| fetchQuery.getModel().getQueryTree().setRoot(new WordItem(sourceId, idField, true)); | ||
| fetchQuery.setHits(1); | ||
| fetchQuery.getRanking().setProfile("unranked"); | ||
|
|
||
| Result result = execution.search(fetchQuery); | ||
| execution.fill(result, summary); | ||
|
|
||
| if (result.hits().size() < 1) { | ||
| return null; | ||
| } | ||
|
|
||
| Hit hit = result.hits().get(0); | ||
| Object field = hit.getField(embeddingField); | ||
| if (field instanceof Tensor tensor) { | ||
| return tensor; | ||
| } | ||
| return null; | ||
| } | ||
|
|
||
| private void addNearestNeighborItem(Tensor embedding, String embeddingField, String queryTensorName, | ||
| int targetHits, int exploreAdditionalHits, Query query) { | ||
| query.getRanking().getFeatures().put("query(" + queryTensorName + ")", embedding); | ||
|
|
||
| NearestNeighborItem nnItem = new NearestNeighborItem(embeddingField, queryTensorName); | ||
| nnItem.setAllowApproximate(true); | ||
| nnItem.setTargetNumHits(targetHits); | ||
| nnItem.setHnswExploreAdditionalHits(exploreAdditionalHits); | ||
|
|
||
| Item root = query.getModel().getQueryTree().getRoot(); | ||
| if (root instanceof NullItem || root == null) { | ||
| query.getModel().getQueryTree().setRoot(nnItem); | ||
| } else { | ||
| AndItem andItem = new AndItem(); | ||
| andItem.addItem(root); | ||
| andItem.addItem(nnItem); | ||
| query.getModel().getQueryTree().setRoot(andItem); | ||
| } | ||
| } | ||
|
|
||
| private void excludeSourceDocument(String sourceId, String idField, Query query) { | ||
| NotItem notItem = new NotItem(); | ||
| notItem.addPositiveItem(query.getModel().getQueryTree().getRoot()); | ||
| notItem.addNegativeItem(new WordItem(sourceId, idField, true)); | ||
| query.getModel().getQueryTree().setRoot(notItem); | ||
| } | ||
|
|
||
| } | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,211 @@ | ||
| // Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. | ||
|
|
||
| package com.yahoo.search.searchers; | ||
|
|
||
| import com.yahoo.prelude.query.AndItem; | ||
| import com.yahoo.prelude.query.NearestNeighborItem; | ||
| import com.yahoo.prelude.query.NotItem; | ||
| import com.yahoo.search.Query; | ||
| import com.yahoo.search.Result; | ||
| import com.yahoo.search.Searcher; | ||
| import com.yahoo.search.result.Hit; | ||
| import com.yahoo.search.searchchain.Execution; | ||
| import com.yahoo.tensor.Tensor; | ||
| import com.yahoo.tensor.TensorType; | ||
| import org.junit.jupiter.api.Test; | ||
|
|
||
| import java.util.List; | ||
|
|
||
| import static org.junit.jupiter.api.Assertions.*; | ||
|
|
||
| /** | ||
| * @author andreer | ||
| */ | ||
| public class RelatedDocumentsByNearestNeighborSearcherTestCase { | ||
|
|
||
| private static final TensorType EMBEDDING_TYPE = TensorType.fromSpec("tensor<float>(x[4])"); | ||
| private static final Tensor TEST_EMBEDDING = Tensor.from(EMBEDDING_TYPE, "[1.0, 2.0, 3.0, 4.0]"); | ||
|
|
||
| @Test | ||
| void testNoRelatedToIdPassesThrough() { | ||
| var searcher = new RelatedDocumentsByNearestNeighborSearcher(); | ||
| var query = new Query("?query=test"); | ||
| var result = execute(searcher, query); | ||
|
|
||
| assertNull(result.hits().getError()); | ||
| } | ||
|
|
||
| @Test | ||
| void testMissingEmbeddingFieldReturnsError() { | ||
| var searcher = new RelatedDocumentsByNearestNeighborSearcher(); | ||
| var query = new Query("?relatedTo.id=doc1"); | ||
| var result = executeWithMockBackend(searcher, query); | ||
|
|
||
| assertNotNull(result.hits().getError(), "Expected error but got none"); | ||
| assertTrue(result.hits().getError().getDetailedMessage().contains("relatedTo.embeddingField is required"), | ||
| "Error message was: " + result.hits().getError().getDetailedMessage()); | ||
| } | ||
|
|
||
| @Test | ||
| void testMissingQueryTensorNameReturnsError() { | ||
| var searcher = new RelatedDocumentsByNearestNeighborSearcher(); | ||
| var query = new Query("?relatedTo.id=doc1&relatedTo.embeddingField=embedding"); | ||
| var result = executeWithMockBackend(searcher, query); | ||
|
|
||
| assertNotNull(result.hits().getError(), "Expected error but got none"); | ||
| assertTrue(result.hits().getError().getDetailedMessage().contains("relatedTo.queryTensorName is required"), | ||
| "Error message was: " + result.hits().getError().getDetailedMessage()); | ||
| } | ||
|
|
||
| @Test | ||
| void testDocumentNotFoundReturnsError() { | ||
| var searcher = new RelatedDocumentsByNearestNeighborSearcher(); | ||
| var query = new Query("?relatedTo.id=doc1&relatedTo.embeddingField=embedding&relatedTo.queryTensorName=q"); | ||
| var result = executeWithMockBackend(searcher, query); | ||
|
|
||
| assertNotNull(result.hits().getError(), "Expected error but got none"); | ||
| assertTrue(result.hits().getError().getDetailedMessage().contains("Could not find document"), | ||
| "Error message was: " + result.hits().getError().getDetailedMessage()); | ||
| } | ||
|
Comment on lines
+61
to
+69
|
||
|
|
||
| @Test | ||
| void testNearestNeighborQueryIsCreated() { | ||
| var searcher = new RelatedDocumentsByNearestNeighborSearcher(); | ||
| var query = new Query("?relatedTo.id=doc1&relatedTo.embeddingField=embedding&relatedTo.queryTensorName=q"); | ||
|
|
||
| var sourceHit = new Hit("source"); | ||
| sourceHit.setField("embedding", TEST_EMBEDDING); | ||
|
|
||
| var capturedQuery = new Query[1]; | ||
| var result = executeWithCapture(searcher, query, sourceHit, capturedQuery); | ||
|
|
||
| assertNull(result.hits().getError()); | ||
| assertNotNull(capturedQuery[0]); | ||
|
|
||
| var root = capturedQuery[0].getModel().getQueryTree().getRoot(); | ||
| assertInstanceOf(NotItem.class, root, "Root should be NotItem for exclusion"); | ||
| var notItem = (NotItem) root; | ||
| var positive = notItem.getPositiveItem(); | ||
| assertInstanceOf(NearestNeighborItem.class, positive, "Positive item should be NearestNeighborItem"); | ||
|
|
||
| var nnItem = (NearestNeighborItem) positive; | ||
| assertEquals("embedding", nnItem.getIndexName()); | ||
| // Default: hits=10, offset=0 -> targetHits=10, exploreAdditionalHits=100 | ||
| assertEquals(10, nnItem.getTargetNumHits()); | ||
| assertEquals(100, nnItem.getHnswExploreAdditionalHits()); | ||
| assertTrue(nnItem.getAllowApproximate()); | ||
| } | ||
|
|
||
| @Test | ||
| void testExcludeSourceCanBeDisabled() { | ||
| var searcher = new RelatedDocumentsByNearestNeighborSearcher(); | ||
| var query = new Query("?relatedTo.id=doc1&relatedTo.embeddingField=embedding&relatedTo.queryTensorName=q&relatedTo.excludeSource=false"); | ||
|
|
||
| var sourceHit = new Hit("source"); | ||
| sourceHit.setField("embedding", TEST_EMBEDDING); | ||
|
|
||
| var capturedQuery = new Query[1]; | ||
| var result = executeWithCapture(searcher, query, sourceHit, capturedQuery); | ||
|
|
||
| assertNull(result.hits().getError()); | ||
| var root = capturedQuery[0].getModel().getQueryTree().getRoot(); | ||
| assertInstanceOf(NearestNeighborItem.class, root, "Root should be NearestNeighborItem when exclusion is disabled"); | ||
| } | ||
|
|
||
| @Test | ||
| void testTargetHitsBasedOnQueryHitsAndOffset() { | ||
| var searcher = new RelatedDocumentsByNearestNeighborSearcher(); | ||
| // hits=20, offset=5 -> targetHits=25, exploreAdditionalHits=50 | ||
| var query = new Query("?relatedTo.id=doc1&relatedTo.embeddingField=embedding&relatedTo.queryTensorName=q&relatedTo.exploreAdditionalHits=50&hits=20&offset=5"); | ||
|
|
||
| var sourceHit = new Hit("source"); | ||
| sourceHit.setField("embedding", TEST_EMBEDDING); | ||
|
|
||
| var capturedQuery = new Query[1]; | ||
| executeWithCapture(searcher, query, sourceHit, capturedQuery); | ||
|
|
||
| var root = capturedQuery[0].getModel().getQueryTree().getRoot(); | ||
| var nnItem = findNearestNeighborItem(root); | ||
| assertNotNull(nnItem); | ||
| assertEquals(25, nnItem.getTargetNumHits()); | ||
| assertEquals(50, nnItem.getHnswExploreAdditionalHits()); | ||
| } | ||
|
|
||
| @Test | ||
| void testCombinedWithExistingQuery() { | ||
| var searcher = new RelatedDocumentsByNearestNeighborSearcher(); | ||
| var query = new Query("?query=test&relatedTo.id=doc1&relatedTo.embeddingField=embedding&relatedTo.queryTensorName=q&relatedTo.excludeSource=false"); | ||
|
|
||
| var sourceHit = new Hit("source"); | ||
| sourceHit.setField("embedding", TEST_EMBEDDING); | ||
|
|
||
| var capturedQuery = new Query[1]; | ||
| executeWithCapture(searcher, query, sourceHit, capturedQuery); | ||
|
|
||
| var root = capturedQuery[0].getModel().getQueryTree().getRoot(); | ||
| assertInstanceOf(AndItem.class, root, "Root should be AndItem when combined with existing query"); | ||
| } | ||
|
|
||
| private NearestNeighborItem findNearestNeighborItem(com.yahoo.prelude.query.Item item) { | ||
| if (item instanceof NearestNeighborItem nn) { | ||
| return nn; | ||
| } | ||
| if (item instanceof com.yahoo.prelude.query.CompositeItem composite) { | ||
| for (var child : composite.items()) { | ||
| var found = findNearestNeighborItem(child); | ||
| if (found != null) return found; | ||
| } | ||
| } | ||
| return null; | ||
| } | ||
|
|
||
| private Result execute(Searcher searcher, Query query) { | ||
| return new Execution(searcher, Execution.Context.createContextStub()).search(query); | ||
| } | ||
|
|
||
| private Result executeWithMockBackend(Searcher searcher, Query query) { | ||
| Searcher mockBackend = new Searcher() { | ||
| @Override | ||
| public Result search(Query q, Execution execution) { | ||
| return new Result(q); | ||
| } | ||
| @Override | ||
| public void fill(Result result, String summaryClass, Execution execution) { | ||
| } | ||
| }; | ||
|
|
||
| var chain = new com.yahoo.search.searchchain.SearchChain( | ||
| new com.yahoo.component.ComponentId("test"), | ||
| List.of(searcher, mockBackend)); | ||
|
|
||
| return new Execution(chain, Execution.Context.createContextStub()).search(query); | ||
| } | ||
|
|
||
| private Result executeWithCapture(Searcher searcher, Query query, Hit sourceHit, Query[] capturedQuery) { | ||
| Searcher mockBackend = new Searcher() { | ||
| private boolean firstCall = true; | ||
| @Override | ||
| public Result search(Query q, Execution execution) { | ||
| if (firstCall) { | ||
| firstCall = false; | ||
| Result r = new Result(q); | ||
| r.hits().add(sourceHit); | ||
| return r; | ||
| } else { | ||
| capturedQuery[0] = q; | ||
| return new Result(q); | ||
| } | ||
| } | ||
| @Override | ||
| public void fill(Result result, String summaryClass, Execution execution) { | ||
| } | ||
| }; | ||
|
|
||
| var chain = new com.yahoo.search.searchchain.SearchChain( | ||
| new com.yahoo.component.ComponentId("test"), | ||
| List.of(searcher, mockBackend)); | ||
|
|
||
| return new Execution(chain, Execution.Context.createContextStub()).search(query); | ||
| } | ||
|
|
||
| } | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The fetchEmbedding method should check for errors in the result from execution.search(fetchQuery) before accessing result.hits(). If the fetch query fails (e.g., due to backend communication errors), the method should handle the error appropriately rather than proceeding to check hit count. Consider adding a check like: if (result.hits().getError() != null) return null; after line 100.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@copilot open a new pull request to apply changes based on this feedback