Skip to content

Commit 7191dc3

Browse files
Jelmer Kuperus (jelmerk)
authored and committed
add convenience methods for creating empty index.
1 parent c5f19d7 commit 7191dc3

File tree

9 files changed, with 165 additions (+165) and 21 deletions (-21).

.github/workflows/ci.yml

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,15 +4,13 @@ permissions:
44
checks: write
55

66
on:
7-
pull_request:
8-
paths-ignore:
9-
- '**.md'
107
push:
118
branches-ignore:
129
- '!master'
1310
tags-ignore:
1411
- 'v[0-9]+.[0-9]+.[0-9]+'
15-
12+
paths-ignore:
13+
- '**.md'
1614
jobs:
1715
ci-pipeline:
1816
runs-on: ubuntu-22.04

hnswlib-core/src/main/java/com/github/jelmerk/knn/bruteforce/BruteForceIndex.java

Lines changed: 28 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import com.github.jelmerk.knn.Item;
66
import com.github.jelmerk.knn.SearchResult;
77
import com.github.jelmerk.knn.util.ClassLoaderObjectInputStream;
8+
import com.github.jelmerk.knn.util.DummyComparator;
89

910
import java.io.*;
1011
import java.nio.file.Files;
@@ -26,6 +27,7 @@ public class BruteForceIndex<TId, TVector, TItem extends Item<TId, TVector>, TDi
2627

2728
private static final long serialVersionUID = 1L;
2829

30+
private final boolean immutable;
2931
private final int dimensions;
3032
private final DistanceFunction<TVector, TDistance> distanceFunction;
3133
private final Comparator<TDistance> distanceComparator;
@@ -34,6 +36,7 @@ public class BruteForceIndex<TId, TVector, TItem extends Item<TId, TVector>, TDi
3436
private final Map<TId, Long> deletedItemVersions;
3537

3638
private BruteForceIndex(BruteForceIndex.Builder<TVector, TDistance> builder) {
39+
this.immutable = builder.immutable;
3740
this.dimensions = builder.dimensions;
3841
this.distanceFunction = builder.distanceFunction;
3942
this.distanceComparator = builder.distanceComparator;
@@ -79,6 +82,9 @@ public int getDimensions() {
7982
*/
8083
@Override
8184
public boolean add(TItem item) {
85+
if (immutable) {
86+
throw new UnsupportedOperationException("Index is immutable");
87+
}
8288
if (item.dimensions() != dimensions) {
8389
throw new IllegalArgumentException("Item does not have dimensionality of : " + dimensions);
8490
}
@@ -286,7 +292,7 @@ public static <TId, TVector, TItem extends Item<TId, TVector>, TDistance> BruteF
286292
Builder <TVector, TDistance> newBuilder(int dimensions, DistanceFunction<TVector, TDistance> distanceFunction) {
287293

288294
Comparator<TDistance> distanceComparator = Comparator.naturalOrder();
289-
return new Builder<>(dimensions, distanceFunction, distanceComparator);
295+
return new Builder<>(false, dimensions, distanceFunction, distanceComparator);
290296
}
291297

292298
/**
@@ -301,7 +307,23 @@ Builder <TVector, TDistance> newBuilder(int dimensions, DistanceFunction<TVector
301307
*/
302308
public static <TVector, TDistance> Builder <TVector, TDistance> newBuilder(int dimensions, DistanceFunction<TVector, TDistance> distanceFunction, Comparator<TDistance> distanceComparator) {
303309

304-
return new Builder<>(dimensions, distanceFunction, distanceComparator);
310+
return new Builder<>(false, dimensions, distanceFunction, distanceComparator);
311+
}
312+
313+
/**
314+
* Creates an immutable empty index.
315+
*
316+
* @return the empty index
317+
* @param <TId> Type of the external identifier of an item
318+
* @param <TVector> Type of the vector to perform distance calculation on
319+
* @param <TItem> Type of items stored in the index
320+
* @param <TDistance> Type of distance between items (expect any numeric type: float, double, int, ..)
321+
*/
322+
public static <TId, TVector, TItem extends Item<TId, TVector>, TDistance> BruteForceIndex<TId, TVector, TItem, TDistance> empty() {
323+
BruteForceIndex.Builder<TVector, TDistance> builder = new BruteForceIndex.Builder<>(true,0, (DistanceFunction<TVector, TDistance>) (u, v) -> {
324+
throw new UnsupportedOperationException();
325+
}, new DummyComparator<>());
326+
return builder.build();
305327
}
306328

307329
/**
@@ -318,7 +340,10 @@ public static class Builder <TVector, TDistance> {
318340

319341
private final Comparator<TDistance> distanceComparator;
320342

321-
Builder(int dimensions, DistanceFunction<TVector, TDistance> distanceFunction, Comparator<TDistance> distanceComparator) {
343+
private final boolean immutable;
344+
345+
Builder(boolean immutable, int dimensions, DistanceFunction<TVector, TDistance> distanceFunction, Comparator<TDistance> distanceComparator) {
346+
this.immutable = immutable;
322347
this.dimensions = dimensions;
323348
this.distanceFunction = distanceFunction;
324349
this.distanceComparator = distanceComparator;

hnswlib-core/src/main/java/com/github/jelmerk/knn/hnsw/HnswIndex.java

Lines changed: 42 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ public class HnswIndex<TId, TVector, TItem extends Item<TId, TVector>, TDistance
3333
implements Index<TId, TVector, TItem, TDistance> {
3434

3535
private static final byte VERSION_1 = 0x01;
36+
private static final byte VERSION_2 = 0x02;
3637

3738
private static final long serialVersionUID = 1L;
3839

@@ -42,6 +43,7 @@ public class HnswIndex<TId, TVector, TItem extends Item<TId, TVector>, TDistance
4243
private Comparator<TDistance> distanceComparator;
4344
private MaxValueComparator<TDistance> maxValueDistanceComparator;
4445

46+
private boolean immutable;
4547
private int dimensions;
4648
private int maxItemCount;
4749
private int m;
@@ -74,6 +76,7 @@ public class HnswIndex<TId, TVector, TItem extends Item<TId, TVector>, TDistance
7476

7577
private HnswIndex(RefinedBuilder<TId, TVector, TItem, TDistance> builder) {
7678

79+
this.immutable = builder.immutable;
7780
this.dimensions = builder.dimensions;
7881
this.maxItemCount = builder.maxItemCount;
7982
this.distanceFunction = builder.distanceFunction;
@@ -202,6 +205,9 @@ public boolean remove(TId id, long version) {
202205
*/
203206
@Override
204207
public boolean add(TItem item) {
208+
if (immutable) {
209+
throw new UnsupportedOperationException("Index is immutable");
210+
}
205211
if (item.dimensions() != dimensions) {
206212
throw new IllegalArgumentException("Item does not have dimensionality of : " + dimensions);
207213
}
@@ -757,7 +763,7 @@ public void save(OutputStream out) throws IOException {
757763
}
758764

759765
private void writeObject(ObjectOutputStream oos) throws IOException {
760-
oos.writeByte(VERSION_1);
766+
oos.writeByte(VERSION_2);
761767
oos.writeInt(dimensions);
762768
oos.writeObject(distanceFunction);
763769
oos.writeObject(distanceComparator);
@@ -776,6 +782,7 @@ private void writeObject(ObjectOutputStream oos) throws IOException {
776782
writeMutableObjectLongMap(oos, deletedItemVersions);
777783
writeNodesArray(oos, nodes);
778784
oos.writeInt(entryPoint == null ? -1 : entryPoint.id);
785+
oos.writeBoolean(immutable);
779786
}
780787

781788
@SuppressWarnings("unchecked")
@@ -802,6 +809,8 @@ private void readObject(ObjectInputStream ois) throws IOException, ClassNotFound
802809
this.nodes = readNodesArray(ois, itemSerializer, maxM0, maxM);
803810

804811
int entrypointNodeId = ois.readInt();
812+
813+
this.immutable = version != VERSION_1 && ois.readBoolean();
805814
this.entryPoint = entrypointNodeId == -1 ? null : nodes.get(entrypointNodeId);
806815

807816
this.globalLock = new ReentrantLock();
@@ -1069,7 +1078,26 @@ public static <TVector, TDistance extends Comparable<TDistance>> Builder<TVector
10691078

10701079
Comparator<TDistance> distanceComparator = Comparator.naturalOrder();
10711080

1072-
return new Builder<>(dimensions, distanceFunction, distanceComparator, maxItemCount);
1081+
return new Builder<>(false, dimensions, distanceFunction, distanceComparator, maxItemCount);
1082+
}
1083+
1084+
/**
1085+
* Creates an immutable empty index.
1086+
*
1087+
* @return the empty index
1088+
* @param <TId> Type of the external identifier of an item
1089+
* @param <TVector> Type of the vector to perform distance calculation on
1090+
* @param <TItem> Type of items stored in the index
1091+
* @param <TDistance> Type of distance between items (expect any numeric type: float, double, int, ..)
1092+
*/
1093+
public static <TId, TVector, TItem extends Item<TId, TVector>, TDistance> HnswIndex<TId, TVector, TItem, TDistance> empty() {
1094+
Builder<TVector, TDistance> builder = new Builder<>(true, 0, new DistanceFunction<TVector, TDistance>() {
1095+
@Override
1096+
public TDistance distance(TVector u, TVector v) {
1097+
throw new UnsupportedOperationException();
1098+
}
1099+
}, new DummyComparator<>(), 0);
1100+
return builder.build();
10731101
}
10741102

10751103
/**
@@ -1089,7 +1117,7 @@ public static <TVector, TDistance> Builder<TVector, TDistance> newBuilder(
10891117
Comparator<TDistance> distanceComparator,
10901118
int maxItemCount) {
10911119

1092-
return new Builder<>(dimensions, distanceFunction, distanceComparator, maxItemCount);
1120+
return new Builder<>(false, dimensions, distanceFunction, distanceComparator, maxItemCount);
10931121
}
10941122

10951123
private int assignLevel(TId value, double lambda) {
@@ -1318,6 +1346,7 @@ public static abstract class BuilderBase<TBuilder extends BuilderBase<TBuilder,
13181346
public static final int DEFAULT_EF_CONSTRUCTION = 200;
13191347
public static final boolean DEFAULT_REMOVE_ENABLED = false;
13201348

1349+
boolean immutable;
13211350
int dimensions;
13221351
DistanceFunction<TVector, TDistance> distanceFunction;
13231352
Comparator<TDistance> distanceComparator;
@@ -1329,11 +1358,12 @@ public static abstract class BuilderBase<TBuilder extends BuilderBase<TBuilder,
13291358
int efConstruction = DEFAULT_EF_CONSTRUCTION;
13301359
boolean removeEnabled = DEFAULT_REMOVE_ENABLED;
13311360

1332-
BuilderBase(int dimensions,
1361+
BuilderBase(boolean immutable,
1362+
int dimensions,
13331363
DistanceFunction<TVector, TDistance> distanceFunction,
13341364
Comparator<TDistance> distanceComparator,
13351365
int maxItemCount) {
1336-
1366+
this.immutable = immutable;
13371367
this.dimensions = dimensions;
13381368
this.distanceFunction = distanceFunction;
13391369
this.distanceComparator = distanceComparator;
@@ -1417,12 +1447,13 @@ public static class Builder<TVector, TDistance> extends BuilderBase<Builder<TVec
14171447
* @param distanceFunction the distance function
14181448
* @param maxItemCount the maximum number of elements in the index
14191449
*/
1420-
Builder(int dimensions,
1450+
Builder(boolean immutable,
1451+
int dimensions,
14211452
DistanceFunction<TVector, TDistance> distanceFunction,
14221453
Comparator<TDistance> distanceComparator,
14231454
int maxItemCount) {
14241455

1425-
super(dimensions, distanceFunction, distanceComparator, maxItemCount);
1456+
super(immutable, dimensions, distanceFunction, distanceComparator, maxItemCount);
14261457
}
14271458

14281459
@Override
@@ -1440,7 +1471,7 @@ Builder<TVector, TDistance> self() {
14401471
* @return the builder
14411472
*/
14421473
public <TId, TItem extends Item<TId, TVector>> RefinedBuilder<TId, TVector, TItem, TDistance> withCustomSerializers(ObjectSerializer<TId> itemIdSerializer, ObjectSerializer<TItem> itemSerializer) {
1443-
return new RefinedBuilder<>(dimensions, distanceFunction, distanceComparator, maxItemCount, m, ef, efConstruction,
1474+
return new RefinedBuilder<>(immutable, dimensions, distanceFunction, distanceComparator, maxItemCount, m, ef, efConstruction,
14441475
removeEnabled, itemIdSerializer, itemSerializer);
14451476
}
14461477

@@ -1475,7 +1506,8 @@ public static class RefinedBuilder<TId, TVector, TItem extends Item<TId, TVector
14751506
private ObjectSerializer<TId> itemIdSerializer;
14761507
private ObjectSerializer<TItem> itemSerializer;
14771508

1478-
RefinedBuilder(int dimensions,
1509+
RefinedBuilder(boolean immutable,
1510+
int dimensions,
14791511
DistanceFunction<TVector, TDistance> distanceFunction,
14801512
Comparator<TDistance> distanceComparator,
14811513
int maxItemCount,
@@ -1486,7 +1518,7 @@ public static class RefinedBuilder<TId, TVector, TItem extends Item<TId, TVector
14861518
ObjectSerializer<TId> itemIdSerializer,
14871519
ObjectSerializer<TItem> itemSerializer) {
14881520

1489-
super(dimensions, distanceFunction, distanceComparator, maxItemCount);
1521+
super(immutable, dimensions, distanceFunction, distanceComparator, maxItemCount);
14901522

14911523
this.m = m;
14921524
this.ef = ef;

hnswlib-core/src/main/java/com/github/jelmerk/knn/util/DummyComparator.java

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
package com.github.jelmerk.knn.util;
2+
3+
import java.io.Serializable;
4+
import java.util.Comparator;
5+
6+
/**
7+
* Implementation of {@link Comparator} that is serializable and throws {@link UnsupportedOperationException} when
8+
* compare is called. Useful as a dummy placeholder when you know it will never be called.
9+
*
10+
* @param <T> the type of objects that may be compared by this comparator
11+
*/
12+
public class DummyComparator<T> implements Comparator<T>, Serializable {
13+
14+
@Override
15+
public int compare(T o1, T o2) {
16+
throw new UnsupportedOperationException();
17+
}
18+
}

hnswlib-core/src/test/java/com/github/jelmerk/knn/bruteforce/BruteForceIndexTest.java

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212

1313
import static org.hamcrest.CoreMatchers.*;
1414
import static org.hamcrest.MatcherAssert.assertThat;
15+
import static org.junit.jupiter.api.Assertions.assertThrows;
1516

1617
class BruteForceIndexTest {
1718

@@ -152,6 +153,20 @@ void saveAndLoadIndex() throws IOException {
152153
assertThat(loadedIndex.size(), is(1));
153154
}
154155

156+
@Test
157+
void emptyIndexIsImmutable() {
158+
BruteForceIndex<String, float[], TestItem, Float> index = BruteForceIndex.empty();
159+
160+
assertThrows(
161+
UnsupportedOperationException.class,
162+
() -> index.add(item1),
163+
"Index should be immutable"
164+
);
165+
166+
assertThat(index.size(), is(0));
167+
assertThat(index.getDimensions(), is(0));
168+
}
169+
155170

156171
}
157172

hnswlib-core/src/test/java/com/github/jelmerk/knn/hnsw/HnswIndexTest.java

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212

1313
import static org.hamcrest.CoreMatchers.*;
1414
import static org.hamcrest.MatcherAssert.assertThat;
15+
import static org.junit.jupiter.api.Assertions.assertThrows;
1516

1617
class HnswIndexTest {
1718

@@ -215,4 +216,18 @@ void saveAndLoadIndex() throws IOException {
215216

216217
assertThat(loadedIndex.size(), is(1));
217218
}
219+
220+
@Test
221+
void emptyIndexIsImmutable() {
222+
HnswIndex<String, float[], TestItem, Float> index = HnswIndex.empty();
223+
224+
assertThrows(
225+
UnsupportedOperationException.class,
226+
() -> index.add(item1),
227+
"Index should be immutable"
228+
);
229+
230+
assertThat(index.size(), is(0));
231+
assertThat(index.getDimensions(), is(0));
232+
}
218233
}

hnswlib-scala/src/main/scala/com/github/jelmerk/knn/scalalike/bruteforce/BruteForceIndex.scala

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,20 @@ object BruteForceIndex {
8787

8888
new BruteForceIndex[TId, TVector, TItem, TDistance](jIndex)
8989
}
90+
91+
/**
92+
* Creates an immutable empty index.
93+
*
94+
* @tparam TId Type of the external identifier of an item
95+
* @tparam TVector Type of the vector to perform distance calculation on
96+
* @tparam TItem Type of items stored in the index
97+
* @tparam TDistance Type of distance between items (expect any numeric type: float, double, int, ..)
98+
* @return the index
99+
*/
100+
def empty[TId, TVector, TItem <: Item[TId, TVector], TDistance]: BruteForceIndex[TId, TVector, TItem, TDistance] = {
101+
val jIndex: JBruteForceIndex[TId, TVector, TItem, TDistance] = JBruteForceIndex.empty()
102+
new BruteForceIndex(jIndex)
103+
}
90104
}
91105

92106
/**

0 commit comments

Comments
 (0)