diff --git a/lucene/core/src/java/org/apache/lucene/search/HnswQueueSaturationCollector.java b/lucene/core/src/java/org/apache/lucene/search/HnswQueueSaturationCollector.java index 6acc85fcee66..1d46baba4ff1 100644 --- a/lucene/core/src/java/org/apache/lucene/search/HnswQueueSaturationCollector.java +++ b/lucene/core/src/java/org/apache/lucene/search/HnswQueueSaturationCollector.java @@ -94,8 +94,13 @@ public void nextCandidate() { @Override public KnnSearchStrategy getSearchStrategy() { KnnSearchStrategy delegateStrategy = delegate.getSearchStrategy(); - assert delegateStrategy instanceof KnnSearchStrategy.Hnsw; - return new KnnSearchStrategy.Patience( - this, ((KnnSearchStrategy.Hnsw) delegateStrategy).filteredSearchThreshold()); + if (delegateStrategy instanceof KnnSearchStrategy.Hnsw hnswStrategy) { + return new KnnSearchStrategy.Patience(this, hnswStrategy.filteredSearchThreshold()); + } else if (delegateStrategy instanceof KnnSearchStrategy.Seeded seededStrategy) { + if (seededStrategy.originalStrategy() instanceof KnnSearchStrategy.Hnsw hnswStrategy) { + return new KnnSearchStrategy.Patience(this, hnswStrategy.filteredSearchThreshold()); + } + } + return delegateStrategy; } } diff --git a/lucene/core/src/java/org/apache/lucene/search/PatienceKnnVectorQuery.java b/lucene/core/src/java/org/apache/lucene/search/PatienceKnnVectorQuery.java index ac2d401cbbea..ad7332e4ca8a 100644 --- a/lucene/core/src/java/org/apache/lucene/search/PatienceKnnVectorQuery.java +++ b/lucene/core/src/java/org/apache/lucene/search/PatienceKnnVectorQuery.java @@ -43,8 +43,7 @@ public class PatienceKnnVectorQuery extends AbstractKnnVectorQuery { private final int patience; private final double saturationThreshold; - - final AbstractKnnVectorQuery delegate; + private AbstractKnnVectorQuery delegate; /** * Construct a new PatienceKnnVectorQuery instance for a float vector field @@ -234,4 +233,18 @@ public KnnCollector newCollector( patience); } } + + @Override + public Query rewrite(IndexSearcher indexSearcher) throws IOException { + if (delegate instanceof SeededKnnVectorQuery seededKnnVectorQuery) { + // this is required because SeededKnnVectorQuery now requires its own rewriting logic (to + // create the seed Weight) + delegate = + new SeededKnnVectorQuery( + seededKnnVectorQuery.delegate, + seededKnnVectorQuery.seed, + seededKnnVectorQuery.createSeedWeight(indexSearcher)); + } + return super.rewrite(indexSearcher); + } } diff --git a/lucene/core/src/test/org/apache/lucene/search/TestPatienceByteVectorQuery.java b/lucene/core/src/test/org/apache/lucene/search/TestPatienceByteVectorQuery.java index 60979c27c029..b361d8aaa273 100644 --- a/lucene/core/src/test/org/apache/lucene/search/TestPatienceByteVectorQuery.java +++ b/lucene/core/src/test/org/apache/lucene/search/TestPatienceByteVectorQuery.java @@ -30,13 +30,27 @@ import org.apache.lucene.search.knn.KnnSearchStrategy; import org.apache.lucene.store.Directory; import org.apache.lucene.util.TestVectorUtil; +import org.junit.Before; public class TestPatienceByteVectorQuery extends BaseKnnVectorQueryTestCase { + private boolean wrapSeeded; + + @Before + @Override + public void setUp() throws Exception { + super.setUp(); + wrapSeeded = random().nextBoolean(); + } + @Override PatienceKnnVectorQuery getKnnVectorQuery(String field, float[] query, int k, Query queryFilter) { - return PatienceKnnVectorQuery.fromByteQuery( - new KnnByteVectorQuery(field, floatToBytes(query), k, queryFilter)); + KnnByteVectorQuery knnQuery = + new KnnByteVectorQuery(field, floatToBytes(query), k, queryFilter); + return wrapSeeded + ? PatienceKnnVectorQuery.fromSeededQuery( + SeededKnnVectorQuery.fromByteQuery(knnQuery, new MatchNoDocsQuery())) + : PatienceKnnVectorQuery.fromByteQuery(knnQuery); } @Override @@ -80,7 +94,13 @@ public void testToString() throws IOException { IndexReader reader = DirectoryReader.open(indexStore)) { AbstractKnnVectorQuery query = getKnnVectorQuery("field", new float[] {0.0f, 1.0f}, 10); assertEquals( - "PatienceKnnVectorQuery{saturationThreshold=0.995, patience=7, delegate=KnnByteVectorQuery:field[0,...][10]}", + "PatienceKnnVectorQuery{saturationThreshold=0.995, patience=7, delegate=" + + (wrapSeeded + ? "SeededKnnVectorQuery{seed=MatchNoDocsQuery(\"\"), seedWeight=null, delegate=" + : "") + + "KnnByteVectorQuery:field[0,...][10]" + + (wrapSeeded ? "}" : "") + + "}", query.toString("ignored")); assertDocScoreQueryToString(query.rewrite(newSearcher(reader))); @@ -89,7 +109,13 @@ public void testToString() throws IOException { Query filter = new TermQuery(new Term("id", "text")); query = getKnnVectorQuery("field", new float[] {0.0f, 1.0f}, 10, filter); assertEquals( - "PatienceKnnVectorQuery{saturationThreshold=0.995, patience=7, delegate=KnnByteVectorQuery:field[0,...][10][id:text]}", + "PatienceKnnVectorQuery{saturationThreshold=0.995, patience=7, delegate=" + + (wrapSeeded + ? "SeededKnnVectorQuery{seed=MatchNoDocsQuery(\"\"), seedWeight=null, delegate=" + : "") + + "KnnByteVectorQuery:field[0,...][10][id:text]" + + (wrapSeeded ? "}" : "") + + "}", query.toString("ignored")); } } diff --git a/lucene/core/src/test/org/apache/lucene/search/TestPatienceFloatVectorQuery.java b/lucene/core/src/test/org/apache/lucene/search/TestPatienceFloatVectorQuery.java index 21be86b30cac..75b60bfd36b4 100644 --- a/lucene/core/src/test/org/apache/lucene/search/TestPatienceFloatVectorQuery.java +++ b/lucene/core/src/test/org/apache/lucene/search/TestPatienceFloatVectorQuery.java @@ -28,13 +28,26 @@ import org.apache.lucene.search.knn.KnnSearchStrategy; import org.apache.lucene.store.Directory; import org.apache.lucene.util.TestVectorUtil; +import org.junit.Before; public class TestPatienceFloatVectorQuery extends BaseKnnVectorQueryTestCase { + private boolean wrapSeeded; + + @Before + @Override + public void setUp() throws Exception { + super.setUp(); + wrapSeeded = random().nextBoolean(); + } + @Override PatienceKnnVectorQuery getKnnVectorQuery(String field, float[] query, int k, Query queryFilter) { - return PatienceKnnVectorQuery.fromFloatQuery( - new KnnFloatVectorQuery(field, query, k, queryFilter)); + KnnFloatVectorQuery knnQuery = new KnnFloatVectorQuery(field, query, k, queryFilter); + return wrapSeeded + ? PatienceKnnVectorQuery.fromSeededQuery( + SeededKnnVectorQuery.fromFloatQuery(knnQuery, new MatchNoDocsQuery())) + : PatienceKnnVectorQuery.fromFloatQuery(knnQuery); } @Override @@ -71,7 +84,13 @@ public void testToString() throws IOException { IndexReader reader = DirectoryReader.open(indexStore)) { AbstractKnnVectorQuery query = getKnnVectorQuery("field", new float[] {0.0f, 1.0f}, 10); assertEquals( - "PatienceKnnVectorQuery{saturationThreshold=0.995, patience=7, delegate=KnnFloatVectorQuery:field[0.0,...][10]}", + "PatienceKnnVectorQuery{saturationThreshold=0.995, patience=7, delegate=" + + (wrapSeeded + ? "SeededKnnVectorQuery{seed=MatchNoDocsQuery(\"\"), seedWeight=null, delegate=" + : "") + + "KnnFloatVectorQuery:field[0.0,...][10]" + + (wrapSeeded ? "}" : "") + + "}", query.toString("ignored")); assertDocScoreQueryToString(query.rewrite(newSearcher(reader))); @@ -80,7 +99,13 @@ public void testToString() throws IOException { Query filter = new TermQuery(new Term("id", "text")); query = getKnnVectorQuery("field", new float[] {0.0f, 1.0f}, 10, filter); assertEquals( - "PatienceKnnVectorQuery{saturationThreshold=0.995, patience=7, delegate=KnnFloatVectorQuery:field[0.0,...][10][id:text]}", + "PatienceKnnVectorQuery{saturationThreshold=0.995, patience=7, delegate=" + + (wrapSeeded + ? "SeededKnnVectorQuery{seed=MatchNoDocsQuery(\"\"), seedWeight=null, delegate=" + : "") + + "KnnFloatVectorQuery:field[0.0,...][10][id:text]" + + (wrapSeeded ? "}" : "") + + "}", query.toString("ignored")); } }