Skip to content

Commit f301f2b

Browse files
author
Jelmer Kuperus
committed
add convenience methods for creating empty index.
1 parent c5f19d7 commit f301f2b

File tree

8 files changed

+141
-17
lines changed

8 files changed

+141
-17
lines changed

hnswlib-core/src/main/java/com/github/jelmerk/knn/bruteforce/BruteForceIndex.java

Lines changed: 28 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import com.github.jelmerk.knn.Item;
66
import com.github.jelmerk.knn.SearchResult;
77
import com.github.jelmerk.knn.util.ClassLoaderObjectInputStream;
8+
import com.github.jelmerk.knn.util.DummyComparator;
89

910
import java.io.*;
1011
import java.nio.file.Files;
@@ -26,6 +27,7 @@ public class BruteForceIndex<TId, TVector, TItem extends Item<TId, TVector>, TDi
2627

2728
private static final long serialVersionUID = 1L;
2829

30+
private final boolean immutable;
2931
private final int dimensions;
3032
private final DistanceFunction<TVector, TDistance> distanceFunction;
3133
private final Comparator<TDistance> distanceComparator;
@@ -34,6 +36,7 @@ public class BruteForceIndex<TId, TVector, TItem extends Item<TId, TVector>, TDi
3436
private final Map<TId, Long> deletedItemVersions;
3537

3638
private BruteForceIndex(BruteForceIndex.Builder<TVector, TDistance> builder) {
39+
this.immutable = builder.immutable;
3740
this.dimensions = builder.dimensions;
3841
this.distanceFunction = builder.distanceFunction;
3942
this.distanceComparator = builder.distanceComparator;
@@ -79,6 +82,9 @@ public int getDimensions() {
7982
*/
8083
@Override
8184
public boolean add(TItem item) {
85+
if (immutable) {
86+
throw new UnsupportedOperationException("Index is immutable");
87+
}
8288
if (item.dimensions() != dimensions) {
8389
throw new IllegalArgumentException("Item does not have dimensionality of : " + dimensions);
8490
}
@@ -286,7 +292,7 @@ public static <TId, TVector, TItem extends Item<TId, TVector>, TDistance> BruteF
286292
Builder <TVector, TDistance> newBuilder(int dimensions, DistanceFunction<TVector, TDistance> distanceFunction) {
287293

288294
Comparator<TDistance> distanceComparator = Comparator.naturalOrder();
289-
return new Builder<>(dimensions, distanceFunction, distanceComparator);
295+
return new Builder<>(false, dimensions, distanceFunction, distanceComparator);
290296
}
291297

292298
/**
@@ -301,7 +307,23 @@ Builder <TVector, TDistance> newBuilder(int dimensions, DistanceFunction<TVector
301307
*/
302308
public static <TVector, TDistance> Builder <TVector, TDistance> newBuilder(int dimensions, DistanceFunction<TVector, TDistance> distanceFunction, Comparator<TDistance> distanceComparator) {
303309

304-
return new Builder<>(dimensions, distanceFunction, distanceComparator);
310+
return new Builder<>(false, dimensions, distanceFunction, distanceComparator);
311+
}
312+
313+
/**
314+
* Creates an immutable empty index.
315+
*
316+
* @return the empty index
317+
* @param <TId> Type of the external identifier of an item
318+
* @param <TVector> Type of the vector to perform distance calculation on
319+
* @param <TItem> Type of items stored in the index
320+
* @param <TDistance> Type of distance between items (expect any numeric type: float, double, int, ..)
321+
*/
322+
public static <TId, TVector, TItem extends Item<TId, TVector>, TDistance> BruteForceIndex<TId, TVector, TItem, TDistance> empty() {
323+
BruteForceIndex.Builder<TVector, TDistance> builder = new BruteForceIndex.Builder<>(true,0, (DistanceFunction<TVector, TDistance>) (u, v) -> {
324+
throw new UnsupportedOperationException();
325+
}, new DummyComparator<>());
326+
return builder.build();
305327
}
306328

307329
/**
@@ -318,7 +340,10 @@ public static class Builder <TVector, TDistance> {
318340

319341
private final Comparator<TDistance> distanceComparator;
320342

321-
Builder(int dimensions, DistanceFunction<TVector, TDistance> distanceFunction, Comparator<TDistance> distanceComparator) {
343+
private final boolean immutable;
344+
345+
Builder(boolean immutable, int dimensions, DistanceFunction<TVector, TDistance> distanceFunction, Comparator<TDistance> distanceComparator) {
346+
this.immutable = immutable;
322347
this.dimensions = dimensions;
323348
this.distanceFunction = distanceFunction;
324349
this.distanceComparator = distanceComparator;

hnswlib-core/src/main/java/com/github/jelmerk/knn/hnsw/HnswIndex.java

Lines changed: 42 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ public class HnswIndex<TId, TVector, TItem extends Item<TId, TVector>, TDistance
3333
implements Index<TId, TVector, TItem, TDistance> {
3434

3535
private static final byte VERSION_1 = 0x01;
36+
private static final byte VERSION_2 = 0x02;
3637

3738
private static final long serialVersionUID = 1L;
3839

@@ -42,6 +43,7 @@ public class HnswIndex<TId, TVector, TItem extends Item<TId, TVector>, TDistance
4243
private Comparator<TDistance> distanceComparator;
4344
private MaxValueComparator<TDistance> maxValueDistanceComparator;
4445

46+
private boolean immutable;
4547
private int dimensions;
4648
private int maxItemCount;
4749
private int m;
@@ -74,6 +76,7 @@ public class HnswIndex<TId, TVector, TItem extends Item<TId, TVector>, TDistance
7476

7577
private HnswIndex(RefinedBuilder<TId, TVector, TItem, TDistance> builder) {
7678

79+
this.immutable = builder.immutable;
7780
this.dimensions = builder.dimensions;
7881
this.maxItemCount = builder.maxItemCount;
7982
this.distanceFunction = builder.distanceFunction;
@@ -202,6 +205,9 @@ public boolean remove(TId id, long version) {
202205
*/
203206
@Override
204207
public boolean add(TItem item) {
208+
if (immutable) {
209+
throw new UnsupportedOperationException("Index is immutable");
210+
}
205211
if (item.dimensions() != dimensions) {
206212
throw new IllegalArgumentException("Item does not have dimensionality of : " + dimensions);
207213
}
@@ -757,7 +763,7 @@ public void save(OutputStream out) throws IOException {
757763
}
758764

759765
private void writeObject(ObjectOutputStream oos) throws IOException {
760-
oos.writeByte(VERSION_1);
766+
oos.writeByte(VERSION_2);
761767
oos.writeInt(dimensions);
762768
oos.writeObject(distanceFunction);
763769
oos.writeObject(distanceComparator);
@@ -776,6 +782,7 @@ private void writeObject(ObjectOutputStream oos) throws IOException {
776782
writeMutableObjectLongMap(oos, deletedItemVersions);
777783
writeNodesArray(oos, nodes);
778784
oos.writeInt(entryPoint == null ? -1 : entryPoint.id);
785+
oos.writeBoolean(immutable);
779786
}
780787

781788
@SuppressWarnings("unchecked")
@@ -802,6 +809,8 @@ private void readObject(ObjectInputStream ois) throws IOException, ClassNotFound
802809
this.nodes = readNodesArray(ois, itemSerializer, maxM0, maxM);
803810

804811
int entrypointNodeId = ois.readInt();
812+
813+
this.immutable = version != VERSION_1 && ois.readBoolean();
805814
this.entryPoint = entrypointNodeId == -1 ? null : nodes.get(entrypointNodeId);
806815

807816
this.globalLock = new ReentrantLock();
@@ -1069,7 +1078,26 @@ public static <TVector, TDistance extends Comparable<TDistance>> Builder<TVector
10691078

10701079
Comparator<TDistance> distanceComparator = Comparator.naturalOrder();
10711080

1072-
return new Builder<>(dimensions, distanceFunction, distanceComparator, maxItemCount);
1081+
return new Builder<>(false, dimensions, distanceFunction, distanceComparator, maxItemCount);
1082+
}
1083+
1084+
/**
1085+
* Creates an immutable empty index.
1086+
*
1087+
* @return the empty index
1088+
* @param <TId> Type of the external identifier of an item
1089+
* @param <TVector> Type of the vector to perform distance calculation on
1090+
* @param <TItem> Type of items stored in the index
1091+
* @param <TDistance> Type of distance between items (expect any numeric type: float, double, int, ..)
1092+
*/
1093+
public static <TId, TVector, TItem extends Item<TId, TVector>, TDistance> HnswIndex<TId, TVector, TItem, TDistance> empty() {
1094+
Builder<TVector, TDistance> builder = new Builder<>(true, 0, new DistanceFunction<TVector, TDistance>() {
1095+
@Override
1096+
public TDistance distance(TVector u, TVector v) {
1097+
throw new UnsupportedOperationException();
1098+
}
1099+
}, new DummyComparator<>(), 0);
1100+
return builder.build();
10731101
}
10741102

10751103
/**
@@ -1089,7 +1117,7 @@ public static <TVector, TDistance> Builder<TVector, TDistance> newBuilder(
10891117
Comparator<TDistance> distanceComparator,
10901118
int maxItemCount) {
10911119

1092-
return new Builder<>(dimensions, distanceFunction, distanceComparator, maxItemCount);
1120+
return new Builder<>(false, dimensions, distanceFunction, distanceComparator, maxItemCount);
10931121
}
10941122

10951123
private int assignLevel(TId value, double lambda) {
@@ -1318,6 +1346,7 @@ public static abstract class BuilderBase<TBuilder extends BuilderBase<TBuilder,
13181346
public static final int DEFAULT_EF_CONSTRUCTION = 200;
13191347
public static final boolean DEFAULT_REMOVE_ENABLED = false;
13201348

1349+
boolean immutable;
13211350
int dimensions;
13221351
DistanceFunction<TVector, TDistance> distanceFunction;
13231352
Comparator<TDistance> distanceComparator;
@@ -1329,11 +1358,12 @@ public static abstract class BuilderBase<TBuilder extends BuilderBase<TBuilder,
13291358
int efConstruction = DEFAULT_EF_CONSTRUCTION;
13301359
boolean removeEnabled = DEFAULT_REMOVE_ENABLED;
13311360

1332-
BuilderBase(int dimensions,
1361+
BuilderBase(boolean immutable,
1362+
int dimensions,
13331363
DistanceFunction<TVector, TDistance> distanceFunction,
13341364
Comparator<TDistance> distanceComparator,
13351365
int maxItemCount) {
1336-
1366+
this.immutable = immutable;
13371367
this.dimensions = dimensions;
13381368
this.distanceFunction = distanceFunction;
13391369
this.distanceComparator = distanceComparator;
@@ -1417,12 +1447,13 @@ public static class Builder<TVector, TDistance> extends BuilderBase<Builder<TVec
14171447
* @param distanceFunction the distance function
14181448
* @param maxItemCount the maximum number of elements in the index
14191449
*/
1420-
Builder(int dimensions,
1450+
Builder(boolean immutable,
1451+
int dimensions,
14211452
DistanceFunction<TVector, TDistance> distanceFunction,
14221453
Comparator<TDistance> distanceComparator,
14231454
int maxItemCount) {
14241455

1425-
super(dimensions, distanceFunction, distanceComparator, maxItemCount);
1456+
super(immutable, dimensions, distanceFunction, distanceComparator, maxItemCount);
14261457
}
14271458

14281459
@Override
@@ -1440,7 +1471,7 @@ Builder<TVector, TDistance> self() {
14401471
* @return the builder
14411472
*/
14421473
public <TId, TItem extends Item<TId, TVector>> RefinedBuilder<TId, TVector, TItem, TDistance> withCustomSerializers(ObjectSerializer<TId> itemIdSerializer, ObjectSerializer<TItem> itemSerializer) {
1443-
return new RefinedBuilder<>(dimensions, distanceFunction, distanceComparator, maxItemCount, m, ef, efConstruction,
1474+
return new RefinedBuilder<>(immutable, dimensions, distanceFunction, distanceComparator, maxItemCount, m, ef, efConstruction,
14441475
removeEnabled, itemIdSerializer, itemSerializer);
14451476
}
14461477

@@ -1475,7 +1506,8 @@ public static class RefinedBuilder<TId, TVector, TItem extends Item<TId, TVector
14751506
private ObjectSerializer<TId> itemIdSerializer;
14761507
private ObjectSerializer<TItem> itemSerializer;
14771508

1478-
RefinedBuilder(int dimensions,
1509+
RefinedBuilder(boolean immutable,
1510+
int dimensions,
14791511
DistanceFunction<TVector, TDistance> distanceFunction,
14801512
Comparator<TDistance> distanceComparator,
14811513
int maxItemCount,
@@ -1486,7 +1518,7 @@ public static class RefinedBuilder<TId, TVector, TItem extends Item<TId, TVector
14861518
ObjectSerializer<TId> itemIdSerializer,
14871519
ObjectSerializer<TItem> itemSerializer) {
14881520

1489-
super(dimensions, distanceFunction, distanceComparator, maxItemCount);
1521+
super(immutable, dimensions, distanceFunction, distanceComparator, maxItemCount);
14901522

14911523
this.m = m;
14921524
this.ef = ef;
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
package com.github.jelmerk.knn.util;
2+
3+
import java.io.Serializable;
4+
import java.util.Comparator;
5+
6+
/**
7+
* Implementation of {@link Comparator} that is serializable and throws {@link UnsupportedOperationException} when
8+
* compare is called. Useful as a dummy placeholder when you know it will never be called.
9+
*
10+
* @param <T> the type of objects that may be compared by this comparator
11+
*/
12+
public class DummyComparator<T> implements Comparator<T>, Serializable {
13+
14+
@Override
15+
public int compare(T o1, T o2) {
16+
throw new UnsupportedOperationException();
17+
}
18+
}

hnswlib-core/src/test/java/com/github/jelmerk/knn/bruteforce/BruteForceIndexTest.java

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -152,6 +152,12 @@ void saveAndLoadIndex() throws IOException {
152152
assertThat(loadedIndex.size(), is(1));
153153
}
154154

155+
@Test
156+
void createEmptyIndex() {
157+
BruteForceIndex<String, float[], TestItem, Float> index = BruteForceIndex.empty();
158+
assertThat(index.size(), is(0));
159+
}
160+
155161

156162
}
157163

hnswlib-core/src/test/java/com/github/jelmerk/knn/hnsw/HnswIndexTest.java

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -215,4 +215,10 @@ void saveAndLoadIndex() throws IOException {
215215

216216
assertThat(loadedIndex.size(), is(1));
217217
}
218+
219+
@Test
220+
void emptyIndexIsImmutable() {
221+
HnswIndex<String, float[], TestItem, Float> index = HnswIndex.empty();
222+
assertThat(index.size(), is(0));
223+
}
218224
}

hnswlib-scala/src/main/scala/com/github/jelmerk/knn/scalalike/bruteforce/BruteForceIndex.scala

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,20 @@ object BruteForceIndex {
8787

8888
new BruteForceIndex[TId, TVector, TItem, TDistance](jIndex)
8989
}
90+
91+
/**
92+
* Creates an immutable empty index.
93+
*
94+
* @tparam TId Type of the external identifier of an item
95+
* @tparam TVector Type of the vector to perform distance calculation on
96+
* @tparam TItem Type of items stored in the index
97+
* @tparam TDistance Type of distance between items (expect any numeric type: float, double, int, ..)
98+
* @return the index
99+
*/
100+
def empty[TId, TVector, TItem <: Item[TId, TVector], TDistance]: BruteForceIndex[TId, TVector, TItem, TDistance] = {
101+
val jIndex: JBruteForceIndex[TId, TVector, TItem, TDistance] = JBruteForceIndex.empty()
102+
new BruteForceIndex(jIndex)
103+
}
90104
}
91105

92106
/**

hnswlib-scala/src/main/scala/com/github/jelmerk/knn/scalalike/hnsw/HnswIndex.scala

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -115,8 +115,21 @@ object HnswIndex {
115115
if(removeEnabled) builder.withRemoveEnabled().build()
116116
else builder.build()
117117

118-
new HnswIndex[TId, TVector, TItem, TDistance](jIndex)
118+
new HnswIndex(jIndex)
119+
}
119120

121+
/**
122+
* Creates an immutable empty index.
123+
*
124+
* @tparam TId Type of the external identifier of an item
125+
* @tparam TVector Type of the vector to perform distance calculation on
126+
* @tparam TItem Type of items stored in the index
127+
* @tparam TDistance Type of distance between items (expect any numeric type: float, double, int, ..)
128+
* @return the index
129+
*/
130+
def empty[TId, TVector, TItem <: Item[TId, TVector], TDistance]: HnswIndex[TId, TVector, TItem, TDistance] = {
131+
val jIndex: JHnswIndex[TId, TVector, TItem, TDistance] = JHnswIndex.empty()
132+
new HnswIndex(jIndex)
120133
}
121134

122135
}
@@ -140,13 +153,18 @@ class HnswIndex[TId, TVector, TItem <: Item[TId, TVector], TDistance] private (d
140153
/**
141154
* This distance function.
142155
*/
143-
val distanceFunction: DistanceFunction[TVector, TDistance] = delegate
144-
.getDistanceFunction.asInstanceOf[ScalaDistanceFunctionAdapter[TVector, TDistance]].scalaFunction
156+
val distanceFunction: DistanceFunction[TVector, TDistance] = delegate.getDistanceFunction match {
157+
case a: ScalaDistanceFunctionAdapter[TVector, TDistance] => a.scalaFunction
158+
case f => (v1: TVector, v2: TVector) => f.distance(v1, v2)
159+
}
145160

146161
/**
147162
* The ordering used to compare distances
148163
*/
149-
val distanceOrdering: Ordering[TDistance] = delegate.getDistanceComparator.asInstanceOf[Ordering[TDistance]]
164+
val distanceOrdering: Ordering[TDistance] = delegate.getDistanceComparator match {
165+
case ordering: Ordering[TDistance] => ordering
166+
case c => (x: TDistance, y: TDistance) => c.compare(x, y)
167+
}
150168

151169
/**
152170
* The maximum number of items the index can hold.

hnswlib-scala/src/test/scala/com/github/jelmerk/knn/scalalike/hnsw/HnswIndexSpec.scala

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -187,4 +187,9 @@ class HnswIndexSpec extends AnyFunSuite {
187187
index.asExactIndex.size should be (1)
188188
}
189189

190+
test("creates an empty immutable index") {
191+
val index = HnswIndex.empty[String, Array[Float], TestItem, Float]
192+
index.size should be (0)
193+
}
194+
190195
}

0 commit comments

Comments
 (0)