Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -52,13 +52,12 @@ public String toString() {
}

/**
 * Converts this mode to the corresponding {@link StemMode}.
 * Note the deliberate asymmetry: MULTIPLE maps to StemMode.ALL.
 * The switch expression is exhaustive over this enum's constants, so adding a
 * new constant without a mapping becomes a compile-time error (no default needed).
 */
public StemMode toStemMode() {
    return switch (this) {
        case SHORTEST -> StemMode.SHORTEST;
        case MULTIPLE -> StemMode.ALL;
        case BEST -> StemMode.BEST;
        case NONE -> StemMode.NONE;
    };
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ private void addAnnotationSimple(SimpleIndexingAnnotations simple, String input,
TermOccurrences termOccurrences) {
if (!token.isSpecialToken()) {
if (token.getNumComponents() > 0) {
for (int i = 0; i < token.getNumComponents(); ++i) {
for (int i = 0; i < token.getNumComponents(); i++) {
addAnnotationSimple(simple, input, token.getComponent(i), termOccurrences);
}
return;
Expand All @@ -155,7 +155,7 @@ private void addAnnotationSimple(SimpleIndexingAnnotations simple, String input,
if (from + length > input.length()) return;

if (config.getStemMode() == StemMode.ALL) {
addAllStemsSimple(simple, input, token, from, length, termOccurrences);
addAllStemsSimple(simple, token, from, length, termOccurrences);
} else {
String term = token.getTokenString();
if (term == null || term.trim().isEmpty()) return;
Expand All @@ -167,7 +167,7 @@ private void addAnnotationSimple(SimpleIndexingAnnotations simple, String input,
}
}

private void addAllStemsSimple(SimpleIndexingAnnotations simple, String input, Token token,
private void addAllStemsSimple(SimpleIndexingAnnotations simple, Token token,
int from, int length, TermOccurrences termOccurrences) {
String indexableOriginal = config.getLowercase() ? toLowerCase(token.getOrig()) : token.getOrig();
String term = token.getTokenString();
Expand All @@ -186,7 +186,7 @@ private void addAllStemsSimple(SimpleIndexingAnnotations simple, String input, T

for (int i = 0; i < token.getNumStems(); i++) {
String stem = token.getStem(i);
if (stem.equals(indexableOriginal) || (term != null && stem.equals(term))) continue;
if (stem.equals(indexableOriginal) || stem.equals(term)) continue;
if (stem.length() > config.getMaxTokenLength()) continue;
if (!termOccurrences.termCountBelowLimit(stem)) continue;
simple.add(from, length, stem);
Expand Down

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ record CacheKey(String embedderId, String text){}
builder.cell(rawEmbedding[i], i);
}
var embedding = builder.build();
return normalize ? EmbeddingNormalizer.normalize(embedding, tensorType) : embedding;
return normalize ? embedding.l2Normalize(embedding.type().dimensions().getFirst().name()) : embedding;
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
import ai.vespa.modelintegration.evaluator.OnnxRuntime;
import ai.vespa.modelintegration.utils.ModelPathHelper;
import ai.vespa.modelintegration.utils.OnnxExternalDataResolver;
import com.yahoo.api.annotations.Beta;
import com.yahoo.component.AbstractComponent;
import com.yahoo.component.annotation.Inject;
import com.yahoo.embedding.huggingface.HuggingFaceEmbedderConfig;
Expand All @@ -24,7 +23,10 @@

import static com.yahoo.language.huggingface.ModelInfo.TruncationStrategy.LONGEST_FIRST;

@Beta
/**
* A general embedder for HuggingFace models.
* This will also quantize to the target embedding type.
*/
public class HuggingFaceEmbedder extends AbstractComponent implements Embedder {

private static final Logger log = Logger.getLogger(HuggingFaceEmbedder.class.getName());
Expand Down Expand Up @@ -155,26 +157,28 @@ public Tensor embed(String text, Context context, TensorType targetType) {
throw new IllegalArgumentException("Error in embedding to type '" + targetType + "': dimension should be indexed.");
}
var embeddingResult = lookupOrEvaluate(context, prependInstruction(text, context));
IndexedTensor tokenEmbeddings = embeddingResult.output;
if (targetType.valueType() == TensorType.Value.INT8) {
return binaryQuantization(embeddingResult, targetType);
if (targetType.valueType() == TensorType.Value.INT8 && sizeIndicatesBitPacking(targetType, embeddingResult)) {
return binaryQuantize(embeddingResult, targetType);
} else if (targetType.valueType() == TensorType.Value.INT8) {
return byteQuantize(embeddingResult, targetType);
} else {
Tensor result = analysis.poolingStrategy.toSentenceEmbedding(targetType, tokenEmbeddings, embeddingResult.attentionMask);
return normalize ? EmbeddingNormalizer.normalize(result, targetType) : result;
return poolAndNormalize(embeddingResult, targetType, targetType.dimensions().get(0).size().get());
}
}

/**
 * Returns whether the requested int8 target is small enough to hold the model
 * output bit-packed, i.e. 8 embedding values per int8 cell.
 */
private boolean sizeIndicatesBitPacking(TensorType targetType, HuggingFaceEmbedder.HFEmbeddingResult embeddingResult) {
    long[] outputShape = embeddingResult.output().shape();
    long modelDimensions = outputShape[outputShape.length - 1];
    long targetDimensions = targetType.dimensions().get(0).size().get();
    return targetDimensions <= modelDimensions / 8;
}

/**
 * Prepends the configured instruction prefix to the text, if any applies.
 * Destinations starting with "query" get the query prefix; otherwise the
 * document prefix is used when configured. Returns the text unchanged when
 * no prefix applies.
 */
String prependInstruction(String text, Context context) {
    if (prependQuery != null && !prependQuery.isEmpty() && context.getDestination().startsWith("query"))
        return prependQuery + " " + text;
    if (prependDocument != null && !prependDocument.isEmpty())
        return prependDocument + " " + text;
    return text;
}


private HuggingFaceEmbedder.HFEmbeddingResult lookupOrEvaluate(Context context, String text) {
var key = new HFEmbedderCacheKey(context.getEmbedderId(), text);
return context.computeCachedValueIfAbsent(key, () -> evaluate(context, text));
Expand Down Expand Up @@ -211,24 +215,37 @@ private HuggingFaceEmbedder.HFEmbeddingResult evaluate(Context context, String t
return new HFEmbeddingResult(tokenEmbeddings, attentionMask, context.getEmbedderId());
}

private Tensor binaryQuantization(HuggingFaceEmbedder.HFEmbeddingResult embeddingResult, TensorType targetType) {
long outputDimensions = embeddingResult.output().shape()[2];
private Tensor binaryQuantize(HuggingFaceEmbedder.HFEmbeddingResult embeddingResult, TensorType targetType) {
long targetUnpackagedDimensions = 8 * targetType.dimensions().get(0).size().get();
Tensor packedResult = Tensors.packBits(poolAndNormalize(embeddingResult, targetType, targetUnpackagedDimensions));
if ( ! packedResult.type().equals(targetType))
throw new IllegalStateException("Expected pack_bits to produce " + targetType + ", but got " + packedResult.type());
return packedResult;
}

private Tensor byteQuantize(HuggingFaceEmbedder.HFEmbeddingResult embeddingResult, TensorType targetType) {
long targetDimensions = targetType.dimensions().get(0).size().get();
//🪆 flexibility - packing only the first 8*targetDimension float values from the model output
long targetUnpackagedDimensions = 8 * targetDimensions;
if (targetUnpackagedDimensions > outputDimensions) {
throw new IllegalArgumentException("Cannot pack " + outputDimensions + " into " + targetDimensions + " int8's");
var result = (IndexedTensor)poolAndNormalize(embeddingResult, targetType, targetDimensions);
IndexedTensor.Builder builder = IndexedTensor.Builder.of(targetType);
for (int i = 0; i < targetDimensions; i++) {
double value = result.get(i);
int quantized = (int) Math.round(value * 127.0); // scale to byte
quantized = Math.max(-128, Math.min(127, quantized)); // clamp
builder.cell((byte) quantized, i);
}
// pool and normalize using float version before binary quantization
return builder.build();
}

private Tensor poolAndNormalize(HuggingFaceEmbedder.HFEmbeddingResult embeddingResult, TensorType targetType, long targetDimensions) {
long outputDimensions = embeddingResult.output().shape()[embeddingResult.output().shape().length - 1];
if (targetDimensions > outputDimensions)
throw new IllegalArgumentException("Cannot quantize " + outputDimensions + " dimensions into " + targetType);

TensorType poolingType = new TensorType.Builder(TensorType.Value.FLOAT).
indexed(targetType.indexedSubtype().dimensions().get(0).name(), targetUnpackagedDimensions)
indexed(targetType.indexedSubtype().dimensions().get(0).name(), targetDimensions)
.build();
Tensor result = analysis.poolingStrategy().toSentenceEmbedding(poolingType, embeddingResult.output(), embeddingResult.attentionMask());
result = normalize ? EmbeddingNormalizer.normalize(result, poolingType) : result;
Tensor packedResult = Tensors.packBits(result);
if ( ! packedResult.type().equals(targetType))
throw new IllegalStateException("Expected pack_bits to produce " + targetType + ", but got " + packedResult.type());
return packedResult;
return normalize ? result.l2Normalize(result.type().dimensions().getFirst().name()) : result;
}

private IndexedTensor createTensorRepresentation(List<Long> input, String dimension) {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,21 +1,15 @@
// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package ai.vespa.embedding;


import ai.vespa.modelintegration.evaluator.OnnxRuntime;
import ai.vespa.modelintegration.utils.ModelPathHelper;
import com.yahoo.config.ModelReference;
import com.yahoo.embedding.huggingface.HuggingFaceEmbedderConfig;
import com.yahoo.language.process.Embedder;
import ai.vespa.modelintegration.evaluator.config.OnnxEvaluatorConfig;
import com.yahoo.searchlib.rankingexpression.evaluation.MapContext;
import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode;
import com.yahoo.searchlib.rankingexpression.rule.UnpackBitsNode;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorAddress;
import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.Tensors;
import org.junit.Test;

import java.nio.file.Path;
Expand All @@ -35,30 +29,6 @@ public class HuggingFaceEmbedderTest {
static HuggingFaceEmbedder embedder = getEmbedder();
static HuggingFaceEmbedder normalizedEmbedder = getNormalizedEmbedder();

@Test
public void testBinarization() {
    // Each case: input bit vector -> expected bit-packed int8 tensor (big-endian bit order)
    String[][] cases = {
            {"tensor(x[8]):[0,0,0,0,0,0,0,0]", "tensor<int8>(x[1]):[0]"},
            {"tensor(x[8]):[1,1,1,1,1,1,1,1]", "tensor<int8>(x[1]):[-1]"},
            {"tensor(x[16]):[0,0,0,0,0,0,0,0,1,1,1,1,1,1,1,1]", "tensor<int8>(x[2]):[0, -1]"},
            {"tensor(x[8]):[0,1,0,1,0,1,0,1]", "tensor<int8>(x[1]):[85]"},
            {"tensor(x[8]):[1,0,1,0,1,0,1,0]", "tensor<int8>(x[1]):[-86]"},
            {"tensor(x[16]):[0,1,0,1,0,1,0,1,1,0,1,0,1,0,1,0]", "tensor<int8>(x[2]):[85, -86]"},
            {"tensor(x[8]):[1,1,1,1,0,0,0,0]", "tensor<int8>(x[1]):[-16]"},
            {"tensor(x[8]):[0,0,0,0,1,1,1,1]", "tensor<int8>(x[1]):[15]"},
            {"tensor(x[16]):[1,1,1,1,0,0,0,0,0,0,0,0,1,1,1,1]", "tensor<int8>(x[2]):[-16, 15]"}
    };
    for (String[] testCase : cases)
        assertPackRight(testCase[0], testCase[1]);
}

/** Asserts that bit-packing the given bit vector yields the expected int8 tensor, and that unpacking restores the original. */
private void assertPackRight(String input, String expected) {
    Tensor bits = Tensor.from(input);
    Tensor packed = Tensors.packBits(bits);
    assertEquals(expected, packed.toString());
    // Round-trip: the unpack_bits ranking feature must reproduce compatible output
    Tensor unpacked = expandBitTensor(packed);
    assertEquals(bits.toString(), unpacked.toString());
}

@Test
public void testCaching() {
var context = new Embedder.Context("schema.indexing");
Expand Down Expand Up @@ -91,6 +61,7 @@ public void testCaching() {
embedder.embed(input, copyContext,TensorType.fromSpec("tensor<int8>(x[2])"));
assertNotEquals(modelOuput, copyContext.getCachedValue(key));
}

@Test
public void testEmbedder() {
var context = new Embedder.Context("schema.indexing");
Expand All @@ -110,10 +81,14 @@ public void testEmbedder() {
binarizedResult = embedder.embed(input, context, TensorType.fromSpec(("tensor<int8>(x[48])")));
assertTrue(binarizedResult.toAbbreviatedString().startsWith("tensor<int8>(x[48]):[119, 44"));

// Test byte quantization (1 float per byte, not binary packing)
Tensor byteQuantizedResult = embedder.embed(input, context, TensorType.fromSpec(("tensor<int8>(x[49])")));
assertEquals(49, byteQuantizedResult.size());

assertThrows(IllegalArgumentException.class, () -> {
// throws because the target tensor type is not compatible with the model output
//49*8 > 384
embedder.embed(input, context, TensorType.fromSpec(("tensor<int8>(x[49])")));
// throws because the target tensor dimension exceeds model output dimensions
// model outputs 384 dimensions, so requesting 385 should fail
embedder.embed(input, context, TensorType.fromSpec(("tensor<int8>(x[385])")));
});
Tensor float16Result = embedder.embed(input, context, TensorType.fromSpec(("tensor<bfloat16>(x[1])")));
assertEquals(-0.666, float16Result.sum().asDouble(),1e-3);
Expand All @@ -131,6 +106,27 @@ public void testEmbedderWithNormalization() {
assertEquals("tensor<int8>(x[2]):[119, 44]", binarizedResult.toAbbreviatedString());
}

@Test
public void testByteQuantization() {
    String input = "This is a test";
    var context = new Embedder.Context("schema.indexing");

    Tensor normalizedFloat = normalizedEmbedder.embed(input, context, TensorType.fromSpec("tensor<float>(x[64])"));
    Tensor normalizedInt8 = normalizedEmbedder.embed(input, context, TensorType.fromSpec("tensor<int8>(x[64])"));
    assertEquals(64, normalizedInt8.size());

    // Verify values are in int8 range [-128, 127]
    for (int i = 0; i < 64; i++) {
        double int8Value = normalizedInt8.get(TensorAddress.of(i));
        // JUnit 4 (org.junit): the message is the FIRST argument, not the last
        assertTrue("Value " + int8Value + " at index " + i + " in int8 range",
                   int8Value >= -128 && int8Value <= 127);

        // Verify quantization is approximately correct (float * 127 ≈ int8)
        double floatValue = normalizedFloat.get(TensorAddress.of(i));
        double expectedInt8 = Math.round(floatValue * 127.0);
        assertEquals("Quantization at index " + i, expectedInt8, int8Value, 1.0);
    }
}

@Test
public void testThatWrongTensorTypeThrows() {
var context = new Embedder.Context("schema.indexing");
Expand Down Expand Up @@ -270,13 +266,6 @@ private static HuggingFaceEmbedder getNormalizePrefixdEmbedder() {
return huggingFaceEmbedder;
}

/** Unpacks a bit-packed int8 tensor back into a double tensor via the unpack_bits ranking expression (big-endian bit order). */
public static Tensor expandBitTensor(Tensor packed) {
    var context = new MapContext();
    context.put("input", new TensorValue(packed));
    var unpackExpression = new UnpackBitsNode(new ReferenceNode("input"), TensorType.Value.DOUBLE, "big");
    return unpackExpression.evaluate(context).asTensor();
}

static class MockModelPathHelper implements ModelPathHelper {
Set<String> invokedPaths = new HashSet<>();

Expand Down