Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions lucene/CHANGES.txt
Original file line number Diff line number Diff line change
Expand Up @@ -235,6 +235,9 @@ Optimizations

* GITHUB#14980: Add bulk off-heap scoring for float32 vectors (Chris Hegarty)

* GITHUB#15039: Score computations are now more reliably vectorized.
(Adrien Grand, Guo Feng)

Changes in Runtime Behavior
---------------------
* GITHUB#14823: Decrease TieredMergePolicy's default number of segments per
Expand Down
10 changes: 5 additions & 5 deletions lucene/core/src/java/org/apache/lucene/search/TermScorer.java
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import org.apache.lucene.index.NumericDocValues;
import org.apache.lucene.index.PostingsEnum;
import org.apache.lucene.index.SlowImpactsEnum;
import org.apache.lucene.search.similarities.Similarity.BulkSimScorer;
import org.apache.lucene.search.similarities.Similarity.SimScorer;
import org.apache.lucene.util.ArrayUtil;
import org.apache.lucene.util.Bits;
Expand All @@ -36,6 +37,7 @@ public final class TermScorer extends Scorer {
private final PostingsEnum postingsEnum;
private final DocIdSetIterator iterator;
private final SimScorer scorer;
private final BulkSimScorer bulkScorer;
private final NumericDocValues norms;
private final ImpactsDISI impactsDisi;
private final MaxScoreCache maxScoreCache;
Expand All @@ -49,6 +51,7 @@ public TermScorer(PostingsEnum postingsEnum, SimScorer scorer, NumericDocValues
impactsDisi = null;
this.scorer = scorer;
this.norms = norms;
this.bulkScorer = scorer.asBulkSimScorer();
}

/**
Expand All @@ -71,6 +74,7 @@ public TermScorer(
}
this.scorer = scorer;
this.norms = norms;
this.bulkScorer = scorer.asBulkSimScorer();
}

@Override
Expand Down Expand Up @@ -165,10 +169,6 @@ public void nextDocsAndScores(int upTo, Bits liveDocs, DocAndFloatFeatureBuffer
}
}

for (int i = 0; i < size; ++i) {
// Unless SimScorer#score is megamorphic, SimScorer#score should inline and (part of) score
// computations should auto-vectorize.
buffer.features[i] = scorer.score(buffer.features[i], normValues[i]);
}
bulkScorer.score(buffer.size, buffer.features, normValues, buffer.features);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import org.apache.lucene.search.CollectionStatistics;
import org.apache.lucene.search.Explanation;
import org.apache.lucene.search.TermStatistics;
import org.apache.lucene.util.ArrayUtil;
import org.apache.lucene.util.SmallFloat;

/**
Expand Down Expand Up @@ -217,8 +218,7 @@ private static class BM25Scorer extends SimScorer {
this.weight = boost * idf.getValue().floatValue();
}

@Override
public float score(float freq, long encodedNorm) {
private float doScore(float freq, float normInverse) {
// In order to guarantee monotonicity with both freq and norm without
// promoting to doubles, we rewrite freq / (freq + norm) to
// 1 - 1 / (1 + freq * 1/norm).
Expand All @@ -228,10 +228,38 @@ public float score(float freq, long encodedNorm) {
// x -> 1 + x and x -> 1 - 1/x.
// Finally we expand weight * (1 - 1 / (1 + freq * 1/norm)) to
// weight - weight / (1 + freq * 1/norm), which runs slightly faster.
float normInverse = cache[((byte) encodedNorm) & 0xFF];
return weight - weight / (1f + freq * normInverse);
}

@Override
public float score(float freq, long encodedNorm) {
  // Decode the 8-bit encoded norm into its cached inverse, then delegate to
  // the shared scoring kernel so single-doc and bulk scoring stay consistent.
  return doScore(freq, cache[((byte) encodedNorm) & 0xFF]);
}

@Override
public BulkSimScorer asBulkSimScorer() {
  return new BulkSimScorer() {

    // Lazily grown scratch space for the decoded inverse norms; reused across
    // calls to avoid per-invocation allocation. Not thread-safe by contract.
    private float[] invNorms = new float[0];

    @Override
    public void score(int size, float[] freqs, long[] norms, float[] scores) {
      if (invNorms.length < size) {
        invNorms = new float[ArrayUtil.oversize(size, Float.BYTES)];
      }

      // Decode norms in a dedicated pass so that the scoring pass below only
      // touches float arrays, which lets the JIT auto-vectorize it.
      for (int i = 0; i < size; i++) {
        invNorms[i] = cache[((byte) norms[i]) & 0xFF];
      }

      for (int i = 0; i < size; i++) {
        scores[i] = doScore(freqs[i], invNorms[i]);
      }
    }
  };
}

@Override
public Explanation explain(Explanation freq, long encodedNorm) {
List<Explanation> subs = new ArrayList<>(explainConstantFactors());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
package org.apache.lucene.search.similarities;

import java.util.Collections;
import java.util.Objects;
import org.apache.lucene.analysis.tokenattributes.PositionIncrementAttribute;
import org.apache.lucene.document.NumericDocValuesField;
import org.apache.lucene.index.FieldInvertState;
Expand Down Expand Up @@ -208,6 +209,16 @@ protected SimScorer() {}
*/
public abstract float score(float freq, long norm);

/**
* Return a {@link BulkSimScorer} that produces the exact same scores as this {@link SimScorer}
* but is more efficient at bulk-computing scores.
*
* <p><b>NOTE</b>: The returned instance is not thread-safe.
*/
public BulkSimScorer asBulkSimScorer() {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Change of AssertingSimilarity is not included in this patch, is this intended?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No, it's a miss!

return new DefaultBulkSimScorer(this);
}

/**
* Explain the score for a single document
*
Expand All @@ -223,4 +234,38 @@ public Explanation explain(Explanation freq, long norm) {
Collections.singleton(freq));
}
}

/** Specialization of {@link SimScorer} for bulk-computation of scores. */
public interface BulkSimScorer {

  /**
   * Bulk computation of scores. For each index {@code i} in [0, size), scores[i] is computed as
   * score(freqs[i], norms[i]). The default implementation does the following:
   *
   * <pre class="prettyprint">
   * for (int i = 0; i &lt; size; ++i) {
   *   scores[i] = score(freqs[i], norms[i]);
   * }
   * </pre>
   *
   * <p><b>NOTE</b>: It is legal to pass the same {@code freqs} and {@code scores} arrays.
   */
  void score(int size, float[] freqs, long[] norms, float[] scores);
}

/**
 * Fallback {@link BulkSimScorer} that scores entries one at a time by delegating to the wrapped
 * {@link SimScorer}. Used when a similarity does not provide a specialized bulk implementation.
 */
private static class DefaultBulkSimScorer implements BulkSimScorer {

  private final SimScorer scorer;

  DefaultBulkSimScorer(SimScorer scorer) {
    this.scorer = Objects.requireNonNull(scorer);
  }

  @Override
  public void score(int size, float[] freqs, long[] norms, float[] scores) {
    // Plain scalar loop: this is the correctness baseline that specialized
    // implementations must match exactly.
    for (int i = 0; i < size; i++) {
      scores[i] = scorer.score(freqs[i], norms[i]);
    }
  }
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,26 @@ public Explanation explain(Explanation freq, long norm) {
== delegate.score(freq.getValue().floatValue(), norm);
return explanation;
}

@Override
public BulkSimScorer asBulkSimScorer() {
  final BulkSimScorer delegateBulk = delegate.asBulkSimScorer();
  return new BulkSimScorer() {
    @Override
    public void score(int size, float[] freqs, long[] norms, float[] scores) {
      // Check input invariants before delegating: frequencies must be
      // positive and norms non-zero.
      for (int i = 0; i < size; i++) {
        assert freqs[i] > 0;
        assert norms[i] != 0;
      }
      delegateBulk.score(size, freqs, norms, scores);
      // Check output invariants: every produced score is finite and
      // non-negative.
      for (int i = 0; i < size; i++) {
        assert Float.isFinite(scores[i]);
        assert scores[i] >= 0;
      }
    }
  };
}
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,12 +27,14 @@
import org.apache.lucene.search.TermStatistics;
import org.apache.lucene.search.similarities.IndriDirichletSimilarity;
import org.apache.lucene.search.similarities.Similarity;
import org.apache.lucene.search.similarities.Similarity.BulkSimScorer;
import org.apache.lucene.search.similarities.Similarity.SimScorer;
import org.apache.lucene.store.Directory;
import org.apache.lucene.tests.index.RandomIndexWriter;
import org.apache.lucene.tests.search.CheckHits;
import org.apache.lucene.tests.util.LuceneTestCase;
import org.apache.lucene.tests.util.TestUtil;
import org.apache.lucene.util.ArrayUtil;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.IOUtils;
import org.apache.lucene.util.SmallFloat;
Expand Down Expand Up @@ -521,4 +523,41 @@ private static void doTestScoring(
}
}
}

public void testBulkScore() throws IOException {
  // Verify that the bulk scorer produces exactly the same values as the
  // one-at-a-time SimScorer across random freq/norm inputs.
  Random random = random();
  Similarity similarity = getSimilarity(random);
  CollectionStatistics corpus = newCorpus(random, 1);
  TermStatistics term = newTerm(random, corpus);
  SimScorer scorer = similarity.scorer(random.nextFloat(5f), corpus, term);
  BulkSimScorer bulkScorer = scorer.asBulkSimScorer();
  // Largest legal per-document frequency given the term statistics.
  int freqUpperBound =
      Math.toIntExact(Math.min(term.totalTermFreq() - term.docFreq() + 1, Integer.MAX_VALUE));

  float[] freqs = new float[0];
  long[] norms = new long[0];
  float[] scores = new float[0];

  int iters = atLeast(3);
  for (int iter = 0; iter < iters; ++iter) {
    int size = TestUtil.nextInt(random, 0, 200);
    if (size > freqs.length) {
      freqs = new float[ArrayUtil.oversize(size, Float.BYTES)];
      norms = new long[freqs.length];
      scores = new float[freqs.length];
    }
    for (int i = 0; i < size; ++i) {
      freqs[i] = TestUtil.nextInt(random, 1, freqUpperBound);
      norms[i] = TestUtil.nextLong(random, 1, 255);
    }

    // Reference values computed with the scalar scorer.
    float[] expected = new float[size];
    for (int i = 0; i < size; ++i) {
      expected[i] = scorer.score(freqs[i], norms[i]);
    }
    bulkScorer.score(size, freqs, norms, scores);

    assertArrayEquals(expected, ArrayUtil.copyOfSubArray(scores, 0, size), 0f);
  }
}
}
Loading