Commit 18da404

Author: Lőrinc
Add IntArrayList to store tokens without boxing
Before:
Benchmark                                              (dataFolderPath)  Mode  Cnt  Score   Error  Units
SingleThreadedBenchmark.benchmarkCl100kBase                        data    ss   10  3.263 ± 0.286   s/op
SingleThreadedBenchmark.benchmarkCl100kBaseTokenCount              data    ss   10  2.688 ± 0.054   s/op
SingleThreadedBenchmark.benchmarkP50kBase                          data    ss   10  5.335 ± 0.106   s/op
SingleThreadedBenchmark.benchmarkP50kEdit                          data    ss   10  5.277 ± 0.067   s/op
SingleThreadedBenchmark.benchmarkR50kBase                          data    ss   10  5.002 ± 0.091   s/op

After:
Benchmark                                              (dataFolderPath)  Mode  Cnt  Score   Error  Units
SingleThreadedBenchmark.benchmarkCl100kBase                        data    ss   10  2.498 ± 0.019   s/op
SingleThreadedBenchmark.benchmarkCl100kBaseTokenCount              data    ss   10  2.223 ± 0.014   s/op
SingleThreadedBenchmark.benchmarkP50kBase                          data    ss   10  4.354 ± 0.122   s/op
SingleThreadedBenchmark.benchmarkP50kEdit                          data    ss   10  4.341 ± 0.076   s/op
SingleThreadedBenchmark.benchmarkR50kBase                          data    ss   10  4.068 ± 0.020   s/op
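To put the before/after scores in relative terms, the reduction in time per operation can be computed directly from the numbers in the commit message. The following is only a small arithmetic sketch; the class and method names are made up for illustration and are not part of JTokkit or the benchmark module. With these scores the reduction works out to roughly 17% to 23%, depending on the benchmark.

```java
import java.util.LinkedHashMap;
import java.util.Map;

// Hypothetical helper, not part of the JTokkit benchmark module.
class BenchmarkDelta {

    /** Percentage reduction in time per operation, e.g. 10 s -> 8 s gives 20.0. */
    static double reductionPercent(double before, double after) {
        return (before - after) / before * 100.0;
    }

    public static void main(String[] args) {
        // Scores in s/op, copied from the commit message as {before, after}.
        Map<String, double[]> scores = new LinkedHashMap<>();
        scores.put("benchmarkCl100kBase", new double[]{3.263, 2.498});
        scores.put("benchmarkCl100kBaseTokenCount", new double[]{2.688, 2.223});
        scores.put("benchmarkP50kBase", new double[]{5.335, 4.354});
        scores.put("benchmarkP50kEdit", new double[]{5.277, 4.341});
        scores.put("benchmarkR50kBase", new double[]{5.002, 4.068});

        scores.forEach((name, s) -> System.out.printf(
                "%-32s %.1f%% less time per op%n", name, reductionPercent(s[0], s[1])));
    }
}
```

The error margins reported by JMH are ignored here, so the percentages are point estimates only.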
1 parent f2ccd09 commit 18da404

21 files changed

Lines changed: 751 additions & 510 deletions

README.md

Lines changed: 10 additions & 3 deletions
@@ -6,6 +6,7 @@
 [![javadoc](https://javadoc.io/badge2/com.knuddels/jtokkit/javadoc.svg)](https://javadoc.io/doc/com.knuddels/jtokkit)

 Welcome to JTokkit, a Java tokenizer library designed for use with OpenAI models.
+
 ```java
 EncodingRegistry registry = Encodings.newDefaultEncodingRegistry();
 Encoding enc = registry.getEncoding(EncodingType.CL100K_BASE);
@@ -20,6 +21,7 @@ enc = registry.getEncodingForModel(ModelType.TEXT_EMBEDDING_ADA_002);
 For a quick getting started, see our [documentation](https://jtokkit.knuddels.de/).

 ## 📖 Introduction
+
 JTokkit aims to be a fast and efficient tokenizer designed for use in natural
 language processing tasks using the OpenAI models. It provides an easy-to-use
 interface for tokenizing input text, for example for counting required tokens
@@ -42,7 +44,6 @@ and `cl100k_base`

 ✅ Fast and efficient performance

-
 🔨 Handling of special tokens during encoding (not started)

 ## 📊 Performance
@@ -54,6 +55,7 @@ JTokkit is between 2-3 times faster than a comparable tokenizer.
 For details on the benchmark, see the [benchmark](benchmark) directory.

 ## 🛠️ Installation
+
 You can install JTokkit by adding the following dependency to your Maven project:

 ```xml
@@ -73,16 +75,17 @@ dependencies {
 ```

 ## 🔰 Getting Started
+
 To use JTokkit, simply create a new `EncodingRegistry` and use `getEncoding` to
 retrieve the encoding you want to use. You can then use the `encode` and
 `decode` methods to encode and decode text.

 ```java
 EncodingRegistry registry = Encodings.newDefaultEncodingRegistry();
 Encoding enc = registry.getEncoding(EncodingType.CL100K_BASE);
-List<Integer> encoded = enc.encode("This is a sample sentence.");
+IntArrayList encoded = enc.encode("This is a sample sentence.");
 // encoded = [2028, 374, 264, 6205, 11914, 13]
-
+
 String decoded = enc.decode(encoded);
 // decoded = "This is a sample sentence."

@@ -100,12 +103,15 @@ You may want to extend JTokkit to support custom encodings. To do so, you have t
 options:

 1. Implement the `Encoding` interface and register it with the `EncodingRegistry`
+
 ```java
 EncodingRegistry registry = Encodings.newDefaultEncodingRegistry();
 Encoding customEncoding = new CustomEncoding();
 registry.registerEncoding(customEncoding);
 ```
+
 2. Add new parameters for use with the existing BPE algorithm
+
 ```java
 EncodingRegistry registry = Encodings.newDefaultEncodingRegistry();
 GptBytePairEncodingParams params = new GptBytePairEncodingParams(
@@ -122,6 +128,7 @@ them by using `registry.getEncoding("custom-name")`. See the JavaDoc for more
 details.

 ## 📄 License
+
 JTokkit is licensed under the MIT License. See the
 [LICENSE](https://github.com/knuddelsgmbh/jtokkit/blob/main/LICENSE) file
 for more information.

benchmark/src/jmh/java/com/knuddels/jtokkit/AbstractBenchmark.java

Lines changed: 2 additions & 1 deletion
@@ -1,6 +1,7 @@
 package com.knuddels.jtokkit;

 import com.knuddels.jtokkit.api.Encoding;
+import com.knuddels.jtokkit.api.IntArrayList;
 import org.openjdk.jmh.annotations.Benchmark;

 import java.util.List;
@@ -34,5 +35,5 @@ public Object benchmarkCl100kBase(BenchmarkingState state) {
      * @param fileContents the file contents to encode
      * @return a list of encoded token lists
      */
-    protected abstract List<List<Integer>> encodeAll(Encoding encoding, List<String> fileContents);
+    protected abstract List<IntArrayList> encodeAll(Encoding encoding, List<String> fileContents);
 }
Lines changed: 29 additions & 27 deletions
@@ -1,46 +1,48 @@
 package com.knuddels.jtokkit;

 import com.knuddels.jtokkit.api.Encoding;
+import com.knuddels.jtokkit.api.IntArrayList;
+import org.openjdk.jmh.annotations.Scope;
+import org.openjdk.jmh.annotations.Setup;
+import org.openjdk.jmh.annotations.State;
+import org.openjdk.jmh.annotations.TearDown;
+
 import java.util.List;
 import java.util.concurrent.CompletableFuture;
 import java.util.concurrent.ExecutorService;
 import java.util.concurrent.Executors;
 import java.util.stream.Collectors;
-import org.openjdk.jmh.annotations.Scope;
-import org.openjdk.jmh.annotations.Setup;
-import org.openjdk.jmh.annotations.State;
-import org.openjdk.jmh.annotations.TearDown;

 @State(Scope.Thread)
 public abstract class AbstractMultiThreadedBenchmark extends AbstractBenchmark {

-    private final int threads;
-    private ExecutorService executor;
+    private final int threads;
+    private ExecutorService executor;

-    public AbstractMultiThreadedBenchmark(final int threads) {
-        this.threads = threads;
-    }
+    public AbstractMultiThreadedBenchmark(int threads) {
+        this.threads = threads;
+    }

-    @Setup
-    public void setup() {
-        executor = Executors.newFixedThreadPool(threads);
-    }
+    @Setup
+    public void setup() {
+        executor = Executors.newFixedThreadPool(threads);
+    }

-    @TearDown
-    public void tearDown() {
-        executor.shutdown();
-    }
+    @TearDown
+    public void tearDown() {
+        executor.shutdown();
+    }

-    @Override
-    protected List<List<Integer>> encodeAll(final Encoding encoding, final List<String> fileContents) {
-        final var futures = fileContents.stream()
-                .map(it -> CompletableFuture.supplyAsync(() -> encoding.encode(it), executor))
-                .collect(Collectors.toList());
+    @Override
+    protected List<IntArrayList> encodeAll(Encoding encoding, List<String> fileContents) {
+        var futures = fileContents.stream()
+                .map(it -> CompletableFuture.supplyAsync(() -> encoding.encode(it), executor))
+                .toList();

-        CompletableFuture.allOf(futures.toArray(CompletableFuture[]::new)).join();
+        CompletableFuture.allOf(futures.toArray(CompletableFuture[]::new)).join();

-        return futures.stream()
-                .map(CompletableFuture::join)
-                .collect(Collectors.toList());
-    }
+        return futures.stream()
+                .map(CompletableFuture::join)
+                .collect(Collectors.toList());
+    }
 }
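The `encodeAll` method above follows a common fan-out/join pattern: submit one task per input, wait for all of them, then collect the results in input order. A minimal standalone sketch of that pattern is shown below. It uses only `java.util.concurrent`; the class name `ParallelMap` and the method `mapAll` are hypothetical and stand in for the benchmark's encoder call.

```java
import java.util.List;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.function.Function;
import java.util.stream.Collectors;

// Hypothetical illustration of the fan-out/join pattern, not JTokkit code.
class ParallelMap {

    /** Applies fn to every input on the executor and returns the results in input order. */
    static <T, R> List<R> mapAll(List<T> inputs, Function<T, R> fn, ExecutorService executor) {
        // Submit every task first so they can run concurrently.
        List<CompletableFuture<R>> futures = inputs.stream()
                .map(in -> CompletableFuture.supplyAsync(() -> fn.apply(in), executor))
                .collect(Collectors.toList());

        // Block until all tasks are done; a failed task surfaces as a CompletionException.
        CompletableFuture.allOf(futures.toArray(new CompletableFuture[0])).join();

        // Every future is complete at this point, so join() returns without waiting.
        return futures.stream()
                .map(CompletableFuture::join)
                .collect(Collectors.toList());
    }

    public static void main(String[] args) {
        ExecutorService executor = Executors.newFixedThreadPool(4);
        try {
            // Prints [1, 2, 3]
            System.out.println(mapAll(List.of("a", "bb", "ccc"), String::length, executor));
        } finally {
            executor.shutdown();
        }
    }
}
```

Because the futures are collected into a list before joining, the output order matches the input order regardless of which task finishes first.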

benchmark/src/jmh/java/com/knuddels/jtokkit/SingleThreadedBenchmark.java

Lines changed: 2 additions & 1 deletion
@@ -1,6 +1,7 @@
 package com.knuddels.jtokkit;

 import com.knuddels.jtokkit.api.Encoding;
+import com.knuddels.jtokkit.api.IntArrayList;
 import org.openjdk.jmh.annotations.Benchmark;

 import java.util.List;
@@ -18,7 +19,7 @@ public int benchmarkCl100kBaseTokenCount(BenchmarkingState state) {
     }

     @Override
-    protected List<List<Integer>> encodeAll(final Encoding encoding, final List<String> fileContents) {
+    protected List<IntArrayList> encodeAll(Encoding encoding, List<String> fileContents) {
         return fileContents.stream()
                 .map(encoding::encode)
                 .toList();

docs/docs/getting-started/usage.md

Lines changed: 17 additions & 8 deletions
@@ -9,9 +9,12 @@ To use JTokkit, first create a new `EncodingRegistry`:
 EncodingRegistry registry = Encodings.newDefaultEncodingRegistry();
 ```

-Make sure to keep a reference to the registry, as the creation of the registry is expensive. Creating the registry loads the vocabularies from the classpath. The registry itself handles caching of the loaded encodings. It is thread-safe and can safely be used concurrently by multiple components.
+Make sure to keep a reference to the registry, as the creation of the registry is expensive. Creating the registry loads
+the vocabularies from the classpath. The registry itself handles caching of the loaded encodings. It is thread-safe and
+can safely be used concurrently by multiple components.

-If you do not want to automatically load all vocabularies of all encodings on registry creation, you can use the following lazy loading registry.
+If you do not want to automatically load all vocabularies of all encodings on registry creation, you can use the
+following lazy loading registry.

 ```java
 EncodingRegistry registry = Encodings.newLazyEncodingRegistry();
@@ -45,7 +48,7 @@ Optional<Encoding> encoding = registry.getEncodingForModel("gpt_4");
 You can use an `Encoding` to encode and decode text:

 ```java
-List<Integer> encoded = encoding.encode("This is a sample sentence.");
+IntArrayList encoded = encoding.encode("This is a sample sentence.");
 // encoded = [2028, 374, 264, 6205, 11914, 13]

 String decoded = encoding.decode(encoded);
@@ -56,7 +59,9 @@ The encoding is also fully thread-safe and can be used concurrently by multiple

 :::info

-Note that the library currently does not support encoding of special tokens. Special tokens are artificial tokens used to unlock capabilities from a model, such as fill-in-the-middle. If the `Encoding#encode` method encounters a special token in the input text, it will throw an `UnsupportedOperationException`.
+Note that the library currently does not support encoding of special tokens. Special tokens are artificial tokens used
+to unlock capabilities from a model, such as fill-in-the-middle. If the `Encoding#encode` method encounters a special
+token in the input text, it will throw an `UnsupportedOperationException`.

 If you want to encode special tokens as if they were normal text, you can use `Encoding#encodeOrdinary` instead:

@@ -72,7 +77,8 @@ encoding.encodeOrdinary("hello <|endoftext|> world");

 ## Counting tokens

-If all you want is the amount of tokens the text encodes to, you can use the shorthand method `Encoding#countTokens` or `Encoding#countTokensOrdinary`:
+If all you want is the amount of tokens the text encodes to, you can use the shorthand method `Encoding#countTokens`
+or `Encoding#countTokensOrdinary`:

 ```java
 int tokenCount = encoding.countTokens("This is a sample sentence.");
@@ -84,16 +90,19 @@ int tokenCount = encoding.countTokensOrdinary("hello <|endoftext|> world");

 ## Encoding text with truncation

-If you want to only encode up until a specified amount of `maxTokens` and truncate after that amount, you can use `Encoding#encode(String, int)` or `Encoding#encodeOrdinary(String, int)`. These methods will truncate the encoded tokens to the specified length. They will automatically handle unicode characters that were split in half by the truncation by removing those tokens from the end of the list.
+If you want to only encode up until a specified amount of `maxTokens` and truncate after that amount, you can
+use `Encoding#encode(String, int)` or `Encoding#encodeOrdinary(String, int)`. These methods will truncate the encoded
+tokens to the specified length. They will automatically handle unicode characters that were split in half by the
+truncation by removing those tokens from the end of the list.

 ```java
-List<Integer> encoded = encoding.encode("This is a sample sentence.", 3);
+IntArrayList encoded = encoding.encode("This is a sample sentence.", 3);
 // encoded = [2028, 374, 264]

 String decoded = encoding.decode(encoded);
 // decoded = "This is a"

-List<Integer> encoded = encoding.encode("I love 🍕", 4);
+IntArrayList encoded = encoding.encode("I love 🍕", 4);
 // encoded = [40, 3021]

 String decoded = encoding.decode(encoded);

lib/src/main/java/com/knuddels/jtokkit/ByteArrayList.java

Lines changed: 58 additions & 7 deletions
@@ -4,11 +4,14 @@

 public class ByteArrayList {
     private byte[] array;
-    private int size;
+    private int size = 0;

     public ByteArrayList() {
-        array = new byte[10];
-        size = 0;
+        this(10);
+    }
+
+    public ByteArrayList(int size) {
+        array = new byte[size];
     }

     public void clear() {
@@ -22,17 +25,65 @@ public void add(byte element) {
         array[size++] = element;
     }

+    public byte get(int index) {
+        return array[index];
+    }
+
+    public int set(int index, byte element) {
+        int old = array[index];
+        array[index] = element;
+        return old;
+    }
+
     private void resize() {
-        byte[] newArray = new byte[array.length * 2];
-        System.arraycopy(array, 0, newArray, 0, size);
+        ensureCapacity(Math.max(1, array.length) * 2);
+    }
+
+    public void ensureCapacity(int targetSize) {
+        if (targetSize <= size) {
+            return;
+        }
+        byte[] newArray = new byte[targetSize];
+        if (size > 0) {
+            System.arraycopy(array, 0, newArray, 0, size);
+        }
         array = newArray;
     }

-    int length() {
+    public int size() {
         return size;
     }

-    public byte[] toByteArray() {
+    public boolean isEmpty() {
+        return size == 0;
+    }
+
+    public byte[] toArray() {
         return Arrays.copyOf(array, size);
     }
+
+    @Override
+    public boolean equals(Object o) {
+        if (this == o) {
+            return true;
+        } else if (o == null || getClass() != o.getClass()) {
+            return false;
+        }
+        ByteArrayList that = (ByteArrayList) o;
+        for (int i = 0; i < size; i++) {
+            if (array[i] != that.array[i]) {
+                return false;
+            }
+        }
+        return true;
+    }
+
+    @Override
+    public int hashCode() {
+        int result = 1;
+        for (int i = 0; i < size; i++) {
+            result = 31 * result + array[i];
+        }
+        return result;
+    }
 }
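The `IntArrayList` source itself is not part of this excerpt. As a rough illustration of the idea in the commit title (storing tokens in an `int[]` instead of boxed `Integer` objects), here is a minimal sketch that assumes the class mirrors the `ByteArrayList` shown above. The name `IntListSketch` and all of its members are hypothetical, not the JTokkit API. One deliberate difference: this sketch compares sizes in `equals` before comparing elements, which the `ByteArrayList.equals` hunk above does not do.

```java
import java.util.Arrays;
import java.util.Objects;

// Hypothetical growable int list; an illustration only, not the JTokkit IntArrayList.
class IntListSketch {
    private int[] array;
    private int size = 0;

    IntListSketch() {
        this(10);
    }

    IntListSketch(int initialCapacity) {
        array = new int[initialCapacity];
    }

    void add(int element) {
        if (size == array.length) {
            // Double the backing array; Math.max guards against a zero-length start.
            array = Arrays.copyOf(array, Math.max(1, array.length) * 2);
        }
        array[size++] = element; // stored as a primitive, so no Integer is allocated
    }

    int get(int index) {
        Objects.checkIndex(index, size); // reject reads past the logical end
        return array[index];
    }

    int size() {
        return size;
    }

    int[] toArray() {
        return Arrays.copyOf(array, size);
    }

    @Override
    public boolean equals(Object o) {
        if (this == o) {
            return true;
        }
        if (!(o instanceof IntListSketch)) {
            return false;
        }
        IntListSketch that = (IntListSketch) o;
        // Compare sizes first so a list never equals a longer list that shares its prefix.
        if (size != that.size) {
            return false;
        }
        for (int i = 0; i < size; i++) {
            if (array[i] != that.array[i]) {
                return false;
            }
        }
        return true;
    }

    @Override
    public int hashCode() {
        int result = 1;
        for (int i = 0; i < size; i++) {
            result = 31 * result + array[i];
        }
        return result;
    }
}
```

The doubling strategy keeps `add` amortized constant time, and only the first `size` slots take part in `equals` and `hashCode`, so unused capacity never affects comparisons.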

lib/src/main/java/com/knuddels/jtokkit/EncodingFactory.java

Lines changed: 4 additions & 3 deletions
@@ -2,6 +2,7 @@

 import com.knuddels.jtokkit.api.Encoding;
 import com.knuddels.jtokkit.api.GptBytePairEncodingParams;
+import com.knuddels.jtokkit.api.IntArrayList;

 import java.io.BufferedReader;
 import java.io.IOException;
@@ -176,11 +177,11 @@ public Cl100kGptBytePairEncoding(GptBytePairEncodingParams params) {
         }

         @Override
-        int encodeOrdinaryInternal(String text, int maxTokenCount, boolean keepEncodings, List<Integer> out) {
+        int encodeOrdinaryInternal(String text, int maxTokenCount, boolean keepEncodings, IntArrayList out) {
             int[] tokenCount = {0};
-            ArrayList<Integer> ranks = new ArrayList<>();
+            IntArrayList ranks = new IntArrayList();
             Cl100kParser.split(text, utf8BytesList -> {
-                tokenCount[0] += encoder.addTokensAndGetCount(maxTokenCount, keepEncodings, utf8BytesList.toByteArray(), out, ranks);
+                tokenCount[0] += encoder.addTokensAndGetCount(maxTokenCount, keepEncodings, utf8BytesList.toArray(), out, ranks);
                 return tokenCount[0] >= maxTokenCount;
             });
             return tokenCount[0];
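The hunk above keeps its running total in a one-element `int[]` because a lambda can only capture effectively final locals, and it appears to use the callback's boolean return value to tell the splitter when to stop. The sketch below shows that pattern in isolation, assuming this reading of the callback is right. `EarlyStopCounter`, `split` and `countUpTo` are hypothetical names, and string lengths stand in for token counts.

```java
import java.util.List;
import java.util.function.Predicate;

// Hypothetical illustration of a mutable counter captured by a lambda; not JTokkit code.
class EarlyStopCounter {

    /** Feeds each piece to the callback and stops as soon as the callback returns true. */
    static void split(List<String> pieces, Predicate<String> shouldStop) {
        for (String piece : pieces) {
            if (shouldStop.test(piece)) {
                return;
            }
        }
    }

    /** Sums piece lengths, stopping after the piece that reaches or exceeds maxCount. */
    static int countUpTo(List<String> pieces, int maxCount) {
        // The array reference is effectively final; only its single element changes.
        int[] count = {0};
        split(pieces, piece -> {
            count[0] += piece.length();
            return count[0] >= maxCount;
        });
        return count[0];
    }
}
```

For example, `EarlyStopCounter.countUpTo(List.of("ab", "cde", "fgh"), 4)` returns 5: the second piece pushes the total past the limit, so the third piece is never visited.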
