Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions container-search/abi-spec.json
Original file line number Diff line number Diff line change
Expand Up @@ -9530,6 +9530,18 @@
"public static final com.yahoo.processing.request.CompoundName dryRunKey"
]
},
"com.yahoo.search.searchers.RelatedDocumentsByNearestNeighborSearcher" : {
"superClass" : "com.yahoo.search.Searcher",
"interfaces" : [ ],
"attributes" : [
"public"
],
"methods" : [
"public void <init>()",
"public com.yahoo.search.Result search(com.yahoo.search.Query, com.yahoo.search.searchchain.Execution)"
],
"fields" : [ ]
},
"com.yahoo.search.searchers.ValidateFuzzySearcher" : {
"superClass" : "com.yahoo.search.Searcher",
"interfaces" : [ ],
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,142 @@
// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.

package com.yahoo.search.searchers;

import com.yahoo.api.annotations.Beta;
import com.yahoo.prelude.query.AndItem;
import com.yahoo.prelude.query.Item;
import com.yahoo.prelude.query.NearestNeighborItem;
import com.yahoo.prelude.query.NotItem;
import com.yahoo.prelude.query.NullItem;
import com.yahoo.prelude.query.WordItem;
import com.yahoo.processing.request.CompoundName;
import com.yahoo.search.Query;
import com.yahoo.search.Result;
import com.yahoo.search.Searcher;
import com.yahoo.search.result.ErrorMessage;
import com.yahoo.search.result.Hit;
import com.yahoo.search.searchchain.Execution;
import com.yahoo.tensor.Tensor;

/**
* Finds documents related to a given document using nearest neighbor search on embeddings.
*
* <p>This searcher takes a document ID, fetches the embedding from that document,
* and performs a nearest neighbor search to find similar documents.</p>
*
* <h2>Query parameters:</h2>
* <ul>
* <li><b>relatedTo.id</b> - The ID of the source document to find related documents for (required)</li>
* <li><b>relatedTo.idField</b> - The field containing the document ID (default: "id")</li>
* <li><b>relatedTo.embeddingField</b> - The embedding field to use for NN search (required)</li>
* <li><b>relatedTo.queryTensorName</b> - The name of the query tensor to use, must match rank profile (required)</li>
* <li><b>relatedTo.summary</b> - The summary class containing the embedding (default: same as embeddingField)</li>
* <li><b>relatedTo.exploreAdditionalHits</b> - Additional candidates to explore beyond hits+offset (default: 100)</li>
* <li><b>relatedTo.excludeSource</b> - Whether to exclude the source document from results (default: true)</li>
* </ul>
*
* @author andreer
*/
@Beta
public class RelatedDocumentsByNearestNeighborSearcher extends Searcher {

// Names of the request properties read by search(); all live under the "relatedTo" prefix.
private static final CompoundName RELATED_TO_ID = CompoundName.from("relatedTo.id");
private static final CompoundName RELATED_TO_ID_FIELD = CompoundName.from("relatedTo.idField");
private static final CompoundName RELATED_TO_EMBEDDING_FIELD = CompoundName.from("relatedTo.embeddingField");
private static final CompoundName RELATED_TO_QUERY_TENSOR_NAME = CompoundName.from("relatedTo.queryTensorName");
private static final CompoundName RELATED_TO_SUMMARY = CompoundName.from("relatedTo.summary");
private static final CompoundName RELATED_TO_EXPLORE_ADDITIONAL_HITS = CompoundName.from("relatedTo.exploreAdditionalHits");
private static final CompoundName RELATED_TO_EXCLUDE_SOURCE = CompoundName.from("relatedTo.excludeSource");

/**
 * Turns a request carrying {@code relatedTo.id} into a nearest neighbor search around the
 * embedding of the referenced document.
 *
 * <p>Requests without {@code relatedTo.id} are passed down the chain untouched. Otherwise the
 * embedding of the source document is looked up, a nearest neighbor item is added to the query
 * tree, the source document is optionally excluded, and the rewritten query is executed.</p>
 *
 * @param query the incoming query; its query tree and rank features are modified in place
 * @param execution the execution used both for the embedding lookup and the final search
 * @return the result of the downstream search, or a result holding an error when a required
 *         parameter is missing or no embedding could be obtained for the source document
 */
@Override
public Result search(Query query, Execution execution) {
    var properties = query.properties();

    String sourceId = properties.getString(RELATED_TO_ID);
    if (sourceId == null) {
        // Not a "related documents" request: leave the query alone.
        return execution.search(query);
    }

    String embeddingField = properties.getString(RELATED_TO_EMBEDDING_FIELD);
    if (embeddingField == null) {
        ErrorMessage missingField = ErrorMessage.createIllegalQuery(
                "relatedTo.embeddingField is required when using relatedTo.id");
        return new Result(query, missingField);
    }

    String queryTensorName = properties.getString(RELATED_TO_QUERY_TENSOR_NAME);
    if (queryTensorName == null) {
        ErrorMessage missingTensorName = ErrorMessage.createIllegalQuery(
                "relatedTo.queryTensorName is required when using relatedTo.id");
        return new Result(query, missingTensorName);
    }

    // Optional parameters; all are read before the lookup query is sent.
    String idField = properties.getString(RELATED_TO_ID_FIELD, "id");
    String summary = properties.getString(RELATED_TO_SUMMARY, embeddingField);
    int exploreAdditionalHits = properties.getInteger(RELATED_TO_EXPLORE_ADDITIONAL_HITS, 100);
    boolean excludeSource = properties.getBoolean(RELATED_TO_EXCLUDE_SOURCE, true);
    // Enough neighbors to cover the requested page.
    int targetHits = query.getHits() + query.getOffset();

    Tensor embedding = fetchEmbedding(sourceId, idField, embeddingField, summary, execution, query);
    if (embedding == null) {
        String reason = "Could not find document with " + idField + "=" + sourceId
                        + " or it has no " + embeddingField;
        return new Result(query, ErrorMessage.createBackendCommunicationError(reason));
    }

    addNearestNeighborItem(embedding, embeddingField, queryTensorName, targetHits, exploreAdditionalHits, query);
    if (excludeSource) {
        excludeSourceDocument(sourceId, idField, query);
    }
    return execution.search(query);
}

private Tensor fetchEmbedding(String sourceId, String idField, String embeddingField, String summary,
Execution execution, Query query) {
Query fetchQuery = new Query();
query.attachContext(fetchQuery);
fetchQuery.getPresentation().setSummary(summary);
fetchQuery.getModel().getQueryTree().setRoot(new WordItem(sourceId, idField, true));
fetchQuery.setHits(1);
fetchQuery.getRanking().setProfile("unranked");

Result result = execution.search(fetchQuery);
execution.fill(result, summary);

if (result.hits().size() < 1) {
return null;
}
Comment on lines +100 to +105
Copy link

Copilot AI Jan 29, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The fetchEmbedding method should check for errors in the result from execution.search(fetchQuery) before accessing result.hits(). If the fetch query fails (e.g., due to backend communication errors), the method should handle the error appropriately rather than proceeding to check hit count. Consider adding a check like: if (result.hits().getError() != null) return null; after line 100.

Copilot uses AI. Check for mistakes.
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@copilot open a new pull request to apply changes based on this feedback


Hit hit = result.hits().get(0);
Object field = hit.getField(embeddingField);
if (field instanceof Tensor tensor) {
return tensor;
}
return null;
}

/**
 * Registers the embedding as a query rank feature and adds a nearest neighbor item to the query tree.
 *
 * <p>If the tree has no root (or a {@link NullItem} root) the nearest neighbor item becomes the
 * root; otherwise the existing root and the new item are combined under an {@link AndItem}.</p>
 *
 * @param embedding the tensor stored as rank feature {@code query(queryTensorName)}
 * @param embeddingField the document field the nearest neighbor item searches
 * @param queryTensorName the query tensor name referenced by the nearest neighbor item
 * @param targetHits the target number of hits set on the item
 * @param exploreAdditionalHits the HNSW "explore additional hits" value set on the item
 * @param query the query to modify
 */
private void addNearestNeighborItem(Tensor embedding, String embeddingField, String queryTensorName,
                                    int targetHits, int exploreAdditionalHits, Query query) {
    String featureName = "query(" + queryTensorName + ")";
    query.getRanking().getFeatures().put(featureName, embedding);

    NearestNeighborItem nearestNeighbor = new NearestNeighborItem(embeddingField, queryTensorName);
    nearestNeighbor.setAllowApproximate(true);
    nearestNeighbor.setTargetNumHits(targetHits);
    nearestNeighbor.setHnswExploreAdditionalHits(exploreAdditionalHits);

    var tree = query.getModel().getQueryTree();
    Item existingRoot = tree.getRoot();
    boolean hasExistingRoot = existingRoot != null && !(existingRoot instanceof NullItem);
    if (!hasExistingRoot) {
        tree.setRoot(nearestNeighbor);
        return;
    }

    // Keep whatever the user asked for and require the neighbor match in addition.
    AndItem conjunction = new AndItem();
    conjunction.addItem(existingRoot);
    conjunction.addItem(nearestNeighbor);
    tree.setRoot(conjunction);
}

/**
 * Wraps the current query root in a {@link NotItem} whose negative branch matches the source document.
 *
 * @param sourceId the value identifying the document to exclude
 * @param idField the field matched against {@code sourceId}
 * @param query the query whose tree root is replaced
 */
private void excludeSourceDocument(String sourceId, String idField, Query query) {
    var tree = query.getModel().getQueryTree();
    NotItem exclusion = new NotItem();
    exclusion.addPositiveItem(tree.getRoot());
    exclusion.addNegativeItem(new WordItem(sourceId, idField, true));
    tree.setRoot(exclusion);
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,211 @@
// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.

package com.yahoo.search.searchers;

import com.yahoo.prelude.query.AndItem;
import com.yahoo.prelude.query.NearestNeighborItem;
import com.yahoo.prelude.query.NotItem;
import com.yahoo.search.Query;
import com.yahoo.search.Result;
import com.yahoo.search.Searcher;
import com.yahoo.search.result.Hit;
import com.yahoo.search.searchchain.Execution;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
import org.junit.jupiter.api.Test;

import java.util.List;

import static org.junit.jupiter.api.Assertions.*;

/**
* @author andreer
*/
public class RelatedDocumentsByNearestNeighborSearcherTestCase {

// Fixture: a 4-element indexed float tensor used as the source document's embedding in the tests.
private static final TensorType EMBEDDING_TYPE = TensorType.fromSpec("tensor<float>(x[4])");
private static final Tensor TEST_EMBEDDING = Tensor.from(EMBEDDING_TYPE, "[1.0, 2.0, 3.0, 4.0]");

/** A query without relatedTo.id yields a result without an error. */
@Test
void testNoRelatedToIdPassesThrough() {
    Query plainQuery = new Query("?query=test");
    Result outcome = execute(new RelatedDocumentsByNearestNeighborSearcher(), plainQuery);

    assertNull(outcome.hits().getError());
}

/** relatedTo.id without relatedTo.embeddingField is rejected with an explanatory error. */
@Test
void testMissingEmbeddingFieldReturnsError() {
    Query request = new Query("?relatedTo.id=doc1");
    Result outcome = executeWithMockBackend(new RelatedDocumentsByNearestNeighborSearcher(), request);

    var error = outcome.hits().getError();
    assertNotNull(error, "Expected error but got none");
    String detail = error.getDetailedMessage();
    assertTrue(detail.contains("relatedTo.embeddingField is required"),
               "Error message was: " + detail);
}

/** relatedTo.id without relatedTo.queryTensorName is rejected with an explanatory error. */
@Test
void testMissingQueryTensorNameReturnsError() {
    Query request = new Query("?relatedTo.id=doc1&relatedTo.embeddingField=embedding");
    Result outcome = executeWithMockBackend(new RelatedDocumentsByNearestNeighborSearcher(), request);

    var error = outcome.hits().getError();
    assertNotNull(error, "Expected error but got none");
    String detail = error.getDetailedMessage();
    assertTrue(detail.contains("relatedTo.queryTensorName is required"),
               "Error message was: " + detail);
}

/** When the backend returns no hit for the source document, an error is reported. */
@Test
void testDocumentNotFoundReturnsError() {
    Query request = new Query("?relatedTo.id=doc1&relatedTo.embeddingField=embedding&relatedTo.queryTensorName=q");
    Result outcome = executeWithMockBackend(new RelatedDocumentsByNearestNeighborSearcher(), request);

    var error = outcome.hits().getError();
    assertNotNull(error, "Expected error but got none");
    String detail = error.getDetailedMessage();
    assertTrue(detail.contains("Could not find document"),
               "Error message was: " + detail);
}
Comment on lines +61 to +69
Copy link

Copilot AI Jan 29, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There is no test coverage for the case where the embedding field exists but is not a Tensor type (line 109-112 in the main code). Consider adding a test case that verifies the behavior when the field contains a non-Tensor value, ensuring it returns an appropriate error message to the user.

Copilot uses AI. Check for mistakes.
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@copilot open a new pull request to apply changes based on this feedback


/** With default settings the second query is NOT(nearestNeighbor, ...) using the default hit counts. */
@Test
void testNearestNeighborQueryIsCreated() {
    Query request = new Query("?relatedTo.id=doc1&relatedTo.embeddingField=embedding&relatedTo.queryTensorName=q");

    Hit source = new Hit("source");
    source.setField("embedding", TEST_EMBEDDING);

    Query[] captured = new Query[1];
    Result outcome = executeWithCapture(new RelatedDocumentsByNearestNeighborSearcher(), request, source, captured);

    assertNull(outcome.hits().getError());
    assertNotNull(captured[0]);

    NotItem exclusion = assertInstanceOf(NotItem.class,
                                         captured[0].getModel().getQueryTree().getRoot(),
                                         "Root should be NotItem for exclusion");
    NearestNeighborItem neighbor = assertInstanceOf(NearestNeighborItem.class,
                                                    exclusion.getPositiveItem(),
                                                    "Positive item should be NearestNeighborItem");

    assertEquals("embedding", neighbor.getIndexName());
    // Default: hits=10, offset=0 -> targetHits=10, exploreAdditionalHits=100
    assertEquals(10, neighbor.getTargetNumHits());
    assertEquals(100, neighbor.getHnswExploreAdditionalHits());
    assertTrue(neighbor.getAllowApproximate());
}

/** With relatedTo.excludeSource=false the nearest neighbor item itself is the query root. */
@Test
void testExcludeSourceCanBeDisabled() {
    Query request = new Query("?relatedTo.id=doc1&relatedTo.embeddingField=embedding&relatedTo.queryTensorName=q&relatedTo.excludeSource=false");

    Hit source = new Hit("source");
    source.setField("embedding", TEST_EMBEDDING);

    Query[] captured = new Query[1];
    Result outcome = executeWithCapture(new RelatedDocumentsByNearestNeighborSearcher(), request, source, captured);

    assertNull(outcome.hits().getError());
    assertInstanceOf(NearestNeighborItem.class,
                     captured[0].getModel().getQueryTree().getRoot(),
                     "Root should be NearestNeighborItem when exclusion is disabled");
}

/** targetHits is hits + offset, and exploreAdditionalHits follows the request parameter. */
@Test
void testTargetHitsBasedOnQueryHitsAndOffset() {
    // hits=20, offset=5 -> targetHits=25, exploreAdditionalHits=50
    Query request = new Query("?relatedTo.id=doc1&relatedTo.embeddingField=embedding&relatedTo.queryTensorName=q&relatedTo.exploreAdditionalHits=50&hits=20&offset=5");

    Hit source = new Hit("source");
    source.setField("embedding", TEST_EMBEDDING);

    Query[] captured = new Query[1];
    executeWithCapture(new RelatedDocumentsByNearestNeighborSearcher(), request, source, captured);

    NearestNeighborItem neighbor = findNearestNeighborItem(captured[0].getModel().getQueryTree().getRoot());
    assertNotNull(neighbor);
    assertEquals(25, neighbor.getTargetNumHits());
    assertEquals(50, neighbor.getHnswExploreAdditionalHits());
}

/** An existing user query is kept and combined with the nearest neighbor item under an AndItem. */
@Test
void testCombinedWithExistingQuery() {
    Query request = new Query("?query=test&relatedTo.id=doc1&relatedTo.embeddingField=embedding&relatedTo.queryTensorName=q&relatedTo.excludeSource=false");

    Hit source = new Hit("source");
    source.setField("embedding", TEST_EMBEDDING);

    Query[] captured = new Query[1];
    executeWithCapture(new RelatedDocumentsByNearestNeighborSearcher(), request, source, captured);

    assertInstanceOf(AndItem.class,
                     captured[0].getModel().getQueryTree().getRoot(),
                     "Root should be AndItem when combined with existing query");
}

/**
 * Depth-first search for the first {@link NearestNeighborItem} in the given item tree.
 *
 * @param item the root of the (sub)tree to search
 * @return the first nearest neighbor item found, or null if there is none
 */
private NearestNeighborItem findNearestNeighborItem(com.yahoo.prelude.query.Item item) {
    if (item instanceof NearestNeighborItem match) {
        return match;
    }
    if (!(item instanceof com.yahoo.prelude.query.CompositeItem parent)) {
        return null; // a leaf that is not a nearest neighbor item
    }
    for (var child : parent.items()) {
        NearestNeighborItem nested = findNearestNeighborItem(child);
        if (nested != null) {
            return nested;
        }
    }
    return null;
}

/** Runs the query through an execution containing only the given searcher, using a stub context. */
private Result execute(Searcher searcher, Query query) {
    Execution singleSearcherExecution = new Execution(searcher, Execution.Context.createContextStub());
    return singleSearcherExecution.search(query);
}

/**
 * Runs the query through a chain of the given searcher followed by a backend stub
 * that always answers with an empty result and does nothing on fill.
 */
private Result executeWithMockBackend(Searcher searcher, Query query) {
    Searcher emptyBackend = new Searcher() {
        @Override
        public void fill(Result result, String summaryClass, Execution execution) {
        }
        @Override
        public Result search(Query q, Execution execution) {
            return new Result(q);
        }
    };

    var chainId = new com.yahoo.component.ComponentId("test");
    var chain = new com.yahoo.search.searchchain.SearchChain(chainId, List.of(searcher, emptyBackend));
    Execution chainExecution = new Execution(chain, Execution.Context.createContextStub());
    return chainExecution.search(query);
}

/**
 * Runs the query through a chain of the given searcher followed by a backend stub that
 * answers the first search with {@code sourceHit} and records the query of every later
 * search in {@code capturedQuery[0]}, answering those with an empty result.
 */
private Result executeWithCapture(Searcher searcher, Query query, Hit sourceHit, Query[] capturedQuery) {
    Searcher recordingBackend = new Searcher() {
        private boolean sourceServed = false;
        @Override
        public Result search(Query q, Execution execution) {
            Result r = new Result(q);
            if (sourceServed) {
                capturedQuery[0] = q;
                return r;
            }
            // First call is the embedding lookup: hand back the source document.
            sourceServed = true;
            r.hits().add(sourceHit);
            return r;
        }
        @Override
        public void fill(Result result, String summaryClass, Execution execution) {
        }
    };

    var chainId = new com.yahoo.component.ComponentId("test");
    var chain = new com.yahoo.search.searchchain.SearchChain(chainId, List.of(searcher, recordingBackend));
    Execution chainExecution = new Execution(chain, Execution.Context.createContextStub());
    return chainExecution.search(query);
}

}
Loading