diff --git a/container-search/abi-spec.json b/container-search/abi-spec.json index ddc0ba3b333c..76e88884c4cb 100644 --- a/container-search/abi-spec.json +++ b/container-search/abi-spec.json @@ -9530,6 +9530,18 @@ "public static final com.yahoo.processing.request.CompoundName dryRunKey" ] }, + "com.yahoo.search.searchers.RelatedDocumentsByNearestNeighborSearcher" : { + "superClass" : "com.yahoo.search.Searcher", + "interfaces" : [ ], + "attributes" : [ + "public" + ], + "methods" : [ + "public void ()", + "public com.yahoo.search.Result search(com.yahoo.search.Query, com.yahoo.search.searchchain.Execution)" + ], + "fields" : [ ] + }, "com.yahoo.search.searchers.ValidateFuzzySearcher" : { "superClass" : "com.yahoo.search.Searcher", "interfaces" : [ ], diff --git a/container-search/src/main/java/com/yahoo/search/searchers/RelatedDocumentsByNearestNeighborSearcher.java b/container-search/src/main/java/com/yahoo/search/searchers/RelatedDocumentsByNearestNeighborSearcher.java new file mode 100644 index 000000000000..37932cd281f7 --- /dev/null +++ b/container-search/src/main/java/com/yahoo/search/searchers/RelatedDocumentsByNearestNeighborSearcher.java @@ -0,0 +1,142 @@ +// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +package com.yahoo.search.searchers; + +import com.yahoo.api.annotations.Beta; +import com.yahoo.prelude.query.AndItem; +import com.yahoo.prelude.query.Item; +import com.yahoo.prelude.query.NearestNeighborItem; +import com.yahoo.prelude.query.NotItem; +import com.yahoo.prelude.query.NullItem; +import com.yahoo.prelude.query.WordItem; +import com.yahoo.processing.request.CompoundName; +import com.yahoo.search.Query; +import com.yahoo.search.Result; +import com.yahoo.search.Searcher; +import com.yahoo.search.result.ErrorMessage; +import com.yahoo.search.result.Hit; +import com.yahoo.search.searchchain.Execution; +import com.yahoo.tensor.Tensor; + +/** + * Finds documents related to a given document using nearest neighbor search on embeddings. + * + *

This searcher takes a document ID, fetches the embedding from that document, + * and performs a nearest neighbor search to find similar documents.

+ * + *

Query parameters:

+ * + * + * @author andreer + */ +@Beta +public class RelatedDocumentsByNearestNeighborSearcher extends Searcher { + + private static final CompoundName RELATED_TO_ID = CompoundName.from("relatedTo.id"); + private static final CompoundName RELATED_TO_ID_FIELD = CompoundName.from("relatedTo.idField"); + private static final CompoundName RELATED_TO_EMBEDDING_FIELD = CompoundName.from("relatedTo.embeddingField"); + private static final CompoundName RELATED_TO_QUERY_TENSOR_NAME = CompoundName.from("relatedTo.queryTensorName"); + private static final CompoundName RELATED_TO_SUMMARY = CompoundName.from("relatedTo.summary"); + private static final CompoundName RELATED_TO_EXPLORE_ADDITIONAL_HITS = CompoundName.from("relatedTo.exploreAdditionalHits"); + private static final CompoundName RELATED_TO_EXCLUDE_SOURCE = CompoundName.from("relatedTo.excludeSource"); + + @Override + public Result search(Query query, Execution execution) { + String sourceId = query.properties().getString(RELATED_TO_ID); + if (sourceId == null) { + return execution.search(query); + } + + String embeddingField = query.properties().getString(RELATED_TO_EMBEDDING_FIELD); + if (embeddingField == null) { + return new Result(query, ErrorMessage.createIllegalQuery( + "relatedTo.embeddingField is required when using relatedTo.id")); + } + + String queryTensorName = query.properties().getString(RELATED_TO_QUERY_TENSOR_NAME); + if (queryTensorName == null) { + return new Result(query, ErrorMessage.createIllegalQuery( + "relatedTo.queryTensorName is required when using relatedTo.id")); + } + + String idField = query.properties().getString(RELATED_TO_ID_FIELD, "id"); + String summary = query.properties().getString(RELATED_TO_SUMMARY, embeddingField); + int targetHits = query.getHits() + query.getOffset(); + int exploreAdditionalHits = query.properties().getInteger(RELATED_TO_EXPLORE_ADDITIONAL_HITS, 100); + boolean excludeSource = query.properties().getBoolean(RELATED_TO_EXCLUDE_SOURCE, true); + + Tensor embedding = fetchEmbedding(sourceId, idField, embeddingField, summary, execution, query); + if (embedding == null) { + return new Result(query, ErrorMessage.createBackendCommunicationError( + "Could not find document with " + idField + "=" + sourceId + " or it has no " + embeddingField)); + } + + addNearestNeighborItem(embedding, embeddingField, queryTensorName, targetHits, exploreAdditionalHits, query); + + if (excludeSource) { + excludeSourceDocument(sourceId, idField, query); + } + + return execution.search(query); + } + + private Tensor fetchEmbedding(String sourceId, String idField, String embeddingField, String summary, + Execution execution, Query query) { + Query fetchQuery = new Query(); + query.attachContext(fetchQuery); + fetchQuery.getPresentation().setSummary(summary); + fetchQuery.getModel().getQueryTree().setRoot(new WordItem(sourceId, idField, true)); + fetchQuery.setHits(1); + fetchQuery.getRanking().setProfile("unranked"); + + Result result = execution.search(fetchQuery); + execution.fill(result, summary); + + if (result.hits().size() < 1) { + return null; + } + + Hit hit = result.hits().get(0); + Object field = hit.getField(embeddingField); + if (field instanceof Tensor tensor) { + return tensor; + } + return null; + } + + private void addNearestNeighborItem(Tensor embedding, String embeddingField, String queryTensorName, + int targetHits, int exploreAdditionalHits, Query query) { + query.getRanking().getFeatures().put("query(" + queryTensorName + ")", embedding); + + NearestNeighborItem nnItem = new NearestNeighborItem(embeddingField, queryTensorName); + nnItem.setAllowApproximate(true); + nnItem.setTargetNumHits(targetHits); + nnItem.setHnswExploreAdditionalHits(exploreAdditionalHits); + + Item root = query.getModel().getQueryTree().getRoot(); + if (root instanceof NullItem || root == null) { + query.getModel().getQueryTree().setRoot(nnItem); + } else { + AndItem andItem = new AndItem(); + andItem.addItem(root); + andItem.addItem(nnItem); + query.getModel().getQueryTree().setRoot(andItem); + } + } + + private void excludeSourceDocument(String sourceId, String idField, Query query) { + NotItem notItem = new NotItem(); + notItem.addPositiveItem(query.getModel().getQueryTree().getRoot()); + notItem.addNegativeItem(new WordItem(sourceId, idField, true)); + query.getModel().getQueryTree().setRoot(notItem); + } + +} diff --git a/container-search/src/test/java/com/yahoo/search/searchers/RelatedDocumentsByNearestNeighborSearcherTestCase.java b/container-search/src/test/java/com/yahoo/search/searchers/RelatedDocumentsByNearestNeighborSearcherTestCase.java new file mode 100644 index 000000000000..dd330ec97cbc --- /dev/null +++ b/container-search/src/test/java/com/yahoo/search/searchers/RelatedDocumentsByNearestNeighborSearcherTestCase.java @@ -0,0 +1,211 @@ +// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +package com.yahoo.search.searchers; + +import com.yahoo.prelude.query.AndItem; +import com.yahoo.prelude.query.NearestNeighborItem; +import com.yahoo.prelude.query.NotItem; +import com.yahoo.search.Query; +import com.yahoo.search.Result; +import com.yahoo.search.Searcher; +import com.yahoo.search.result.Hit; +import com.yahoo.search.searchchain.Execution; +import com.yahoo.tensor.Tensor; +import com.yahoo.tensor.TensorType; +import org.junit.jupiter.api.Test; + +import java.util.List; + +import static org.junit.jupiter.api.Assertions.*; + +/** + * @author andreer + */ +public class RelatedDocumentsByNearestNeighborSearcherTestCase { + + private static final TensorType EMBEDDING_TYPE = TensorType.fromSpec("tensor(x[4])"); + private static final Tensor TEST_EMBEDDING = Tensor.from(EMBEDDING_TYPE, "[1.0, 2.0, 3.0, 4.0]"); + + @Test + void testNoRelatedToIdPassesThrough() { + var searcher = new RelatedDocumentsByNearestNeighborSearcher(); + var query = new Query("?query=test"); + var result = execute(searcher, query); + + assertNull(result.hits().getError()); + } + + @Test + void testMissingEmbeddingFieldReturnsError() { + var searcher = new RelatedDocumentsByNearestNeighborSearcher(); + var query = new Query("?relatedTo.id=doc1"); + var result = executeWithMockBackend(searcher, query); + + assertNotNull(result.hits().getError(), "Expected error but got none"); + assertTrue(result.hits().getError().getDetailedMessage().contains("relatedTo.embeddingField is required"), + "Error message was: " + result.hits().getError().getDetailedMessage()); + } + + @Test + void testMissingQueryTensorNameReturnsError() { + var searcher = new RelatedDocumentsByNearestNeighborSearcher(); + var query = new Query("?relatedTo.id=doc1&relatedTo.embeddingField=embedding"); + var result = executeWithMockBackend(searcher, query); + + assertNotNull(result.hits().getError(), "Expected error but got none"); + assertTrue(result.hits().getError().getDetailedMessage().contains("relatedTo.queryTensorName is required"), + "Error message was: " + result.hits().getError().getDetailedMessage()); + } + + @Test + void testDocumentNotFoundReturnsError() { + var searcher = new RelatedDocumentsByNearestNeighborSearcher(); + var query = new Query("?relatedTo.id=doc1&relatedTo.embeddingField=embedding&relatedTo.queryTensorName=q"); + var result = executeWithMockBackend(searcher, query); + + assertNotNull(result.hits().getError(), "Expected error but got none"); + assertTrue(result.hits().getError().getDetailedMessage().contains("Could not find document"), + "Error message was: " + result.hits().getError().getDetailedMessage()); + } + + @Test + void testNearestNeighborQueryIsCreated() { + var searcher = new RelatedDocumentsByNearestNeighborSearcher(); + var query = new Query("?relatedTo.id=doc1&relatedTo.embeddingField=embedding&relatedTo.queryTensorName=q"); + + var sourceHit = new Hit("source"); + sourceHit.setField("embedding", TEST_EMBEDDING); + + var capturedQuery = new Query[1]; + var result = executeWithCapture(searcher, query, sourceHit, capturedQuery); + + assertNull(result.hits().getError()); + assertNotNull(capturedQuery[0]); + + var root = capturedQuery[0].getModel().getQueryTree().getRoot(); + assertInstanceOf(NotItem.class, root, "Root should be NotItem for exclusion"); + var notItem = (NotItem) root; + var positive = notItem.getPositiveItem(); + assertInstanceOf(NearestNeighborItem.class, positive, "Positive item should be NearestNeighborItem"); + + var nnItem = (NearestNeighborItem) positive; + assertEquals("embedding", nnItem.getIndexName()); + // Default: hits=10, offset=0 -> targetHits=10, exploreAdditionalHits=100 + assertEquals(10, nnItem.getTargetNumHits()); + assertEquals(100, nnItem.getHnswExploreAdditionalHits()); + assertTrue(nnItem.getAllowApproximate()); + } + + @Test + void testExcludeSourceCanBeDisabled() { + var searcher = new RelatedDocumentsByNearestNeighborSearcher(); + var query = new Query("?relatedTo.id=doc1&relatedTo.embeddingField=embedding&relatedTo.queryTensorName=q&relatedTo.excludeSource=false"); + + var sourceHit = new Hit("source"); + sourceHit.setField("embedding", TEST_EMBEDDING); + + var capturedQuery = new Query[1]; + var result = executeWithCapture(searcher, query, sourceHit, capturedQuery); + + assertNull(result.hits().getError()); + var root = capturedQuery[0].getModel().getQueryTree().getRoot(); + assertInstanceOf(NearestNeighborItem.class, root, "Root should be NearestNeighborItem when exclusion is disabled"); + } + + @Test + void testTargetHitsBasedOnQueryHitsAndOffset() { + var searcher = new RelatedDocumentsByNearestNeighborSearcher(); + // hits=20, offset=5 -> targetHits=25, exploreAdditionalHits=50 + var query = new Query("?relatedTo.id=doc1&relatedTo.embeddingField=embedding&relatedTo.queryTensorName=q&relatedTo.exploreAdditionalHits=50&hits=20&offset=5"); + + var sourceHit = new Hit("source"); + sourceHit.setField("embedding", TEST_EMBEDDING); + + var capturedQuery = new Query[1]; + executeWithCapture(searcher, query, sourceHit, capturedQuery); + + var root = capturedQuery[0].getModel().getQueryTree().getRoot(); + var nnItem = findNearestNeighborItem(root); + assertNotNull(nnItem); + assertEquals(25, nnItem.getTargetNumHits()); + assertEquals(50, nnItem.getHnswExploreAdditionalHits()); + } + + @Test + void testCombinedWithExistingQuery() { + var searcher = new RelatedDocumentsByNearestNeighborSearcher(); + var query = new Query("?query=test&relatedTo.id=doc1&relatedTo.embeddingField=embedding&relatedTo.queryTensorName=q&relatedTo.excludeSource=false"); + + var sourceHit = new Hit("source"); + sourceHit.setField("embedding", TEST_EMBEDDING); + + var capturedQuery = new Query[1]; + executeWithCapture(searcher, query, sourceHit, capturedQuery); + + var root = capturedQuery[0].getModel().getQueryTree().getRoot(); + assertInstanceOf(AndItem.class, root, "Root should be AndItem when combined with existing query"); + } + + private NearestNeighborItem findNearestNeighborItem(com.yahoo.prelude.query.Item item) { + if (item instanceof NearestNeighborItem nn) { + return nn; + } + if (item instanceof com.yahoo.prelude.query.CompositeItem composite) { + for (var child : composite.items()) { + var found = findNearestNeighborItem(child); + if (found != null) return found; + } + } + return null; + } + + private Result execute(Searcher searcher, Query query) { + return new Execution(searcher, Execution.Context.createContextStub()).search(query); + } + + private Result executeWithMockBackend(Searcher searcher, Query query) { + Searcher mockBackend = new Searcher() { + @Override + public Result search(Query q, Execution execution) { + return new Result(q); + } + @Override + public void fill(Result result, String summaryClass, Execution execution) { + } + }; + + var chain = new com.yahoo.search.searchchain.SearchChain( + new com.yahoo.component.ComponentId("test"), + List.of(searcher, mockBackend)); + + return new Execution(chain, Execution.Context.createContextStub()).search(query); + } + + private Result executeWithCapture(Searcher searcher, Query query, Hit sourceHit, Query[] capturedQuery) { + Searcher mockBackend = new Searcher() { + private boolean firstCall = true; + @Override + public Result search(Query q, Execution execution) { + if (firstCall) { + firstCall = false; + Result r = new Result(q); + r.hits().add(sourceHit); + return r; + } else { + capturedQuery[0] = q; + return new Result(q); + } + } + @Override + public void fill(Result result, String summaryClass, Execution execution) { + } + }; + + var chain = new com.yahoo.search.searchchain.SearchChain( + new com.yahoo.component.ComponentId("test"), + List.of(searcher, mockBackend)); + + return new Execution(chain, Execution.Context.createContextStub()).search(query); + } + +}