Skip to content
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
package com.knuddels.jtokkit;


import org.openjdk.jmh.annotations.Benchmark;
import org.openjdk.jmh.infra.Blackhole;

import java.util.ArrayList;

public class Cl100kParserBenchmark {
@Benchmark
public void benchmarkIsLetter(BenchmarkingState state, Blackhole bh) {
for (var fileContent : state.fileContents) {
fileContent.codePoints().forEachOrdered(cp -> bh.consume(Cl100kParser.isLetter(cp)));
}
}

@Benchmark
public void benchmarkIsLetterOrNumeric(BenchmarkingState state, Blackhole bh) {
for (var fileContent : state.fileContents) {
fileContent.codePoints().forEachOrdered(cp -> bh.consume(Cl100kParser.isLetterOrNumeric(cp)));
}
}

@Benchmark
public void benchmarkIsNewline(BenchmarkingState state, Blackhole bh) {
for (var fileContent : state.fileContents) {
fileContent.codePoints().forEachOrdered(cp -> bh.consume(Cl100kParser.isNewline(cp)));
}
}

@Benchmark
public void benchmarkIsNotNewlineOrLetterOrNumeric(BenchmarkingState state, Blackhole bh) {
for (var fileContent : state.fileContents) {
fileContent.codePoints().forEachOrdered(cp -> bh.consume(Cl100kParser.isNotNewlineOrLetterOrNumeric(cp)));
}
}

@Benchmark
public void benchmarkIsNotWhitespaceOrLetterOrNumeric(BenchmarkingState state, Blackhole bh) {
for (var fileContent : state.fileContents) {
fileContent.codePoints().forEachOrdered(cp -> bh.consume(Cl100kParser.isNotWhitespaceOrLetterOrNumeric(cp)));
}
}

@Benchmark
public void benchmarkIsNumeric(BenchmarkingState state, Blackhole bh) {
for (var fileContent : state.fileContents) {
fileContent.codePoints().forEachOrdered(cp -> bh.consume(Cl100kParser.isNumeric(cp)));
}
}

@Benchmark
public void benchmarkIsWhitespace(BenchmarkingState state, Blackhole bh) {
for (var fileContent : state.fileContents) {
fileContent.codePoints().forEachOrdered(cp -> bh.consume(Cl100kParser.isWhitespace(cp)));
}
}

@Benchmark
public void benchmarkToUtf8Conversion(BenchmarkingState state, Blackhole bh) {
var dst = new ArrayList<Byte>();
for (var fileContent : state.fileContents) {
bh.consume(Cl100kParser.addUtf8Bytes(fileContent, 0, fileContent.length(), dst));
}
}
}
247 changes: 247 additions & 0 deletions lib/src/main/java/com/knuddels/jtokkit/Cl100kParser.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,247 @@
package com.knuddels.jtokkit;


import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.function.Predicate;

import static java.lang.Character.*;
import static java.nio.charset.StandardCharsets.UTF_8;
import static java.util.Arrays.binarySearch;

public class Cl100kParser {
private static final String SDTM = "sdtmSDTMſ";
private static final String SIMPLE_WHITESPACES = "\t\n\u000B\u000C\r";
private static final int[] REMAINING_WHITESPACES = "\u1680\u2000\u2001\u2002\u2003\u2004\u2005\u2006\u2007\u2008\u2009\u200A\u2028\u2029\u202F\u205F\u3000".codePoints().sorted().toArray();

public static void split(String input, Predicate<List<Byte>> fragmentConsumer) {
assert isValidUTF8(input) : "Input is not UTF-8: " + input;
List<Byte> utf8Bytes = new ArrayList<>();
boolean finished = false;
for (int endIndex = 0; endIndex < input.length() && !finished; ) {
int startIndex = endIndex;
int c0 = input.codePointAt(startIndex);
int cc0 = charCount(c0);
int nextIndex = startIndex + cc0;
int c1 = (nextIndex < input.length()) ? input.codePointAt(nextIndex) : -1;

if ((c0 == '\'') && c1 > 0) {
if (isShortContraction(c1)) {
// 1) `'[sdtm]` - contractions, such as the suffixes of `he's`, `I'd`, `'tis`, `I'm`
endIndex += 2;
finished = fragmentConsumer.test(addUtf8Bytes(input, startIndex, endIndex, utf8Bytes));
continue;
} else if ((startIndex + 2) < input.length() && isLongContraction(c1, input.codePointAt(startIndex + 2))) {
// 1) `'(?:ll|ve|re)` - contractions, such as the suffixes of `you'll`, `we've`, `they're`
endIndex += 3;
finished = fragmentConsumer.test(addUtf8Bytes(input, startIndex, endIndex, utf8Bytes));
continue;
}
}

int cc1 = charCount(c1);
if ((isNotNewlineOrLetterOrNumeric(c0) && isLetter(c1)) || isLetter(c0)) {
// 2) `[^\r\n\p{L}\p{N}]?+\p{L}+` - words such as ` of`, `th`, `It`, ` not`
endIndex += cc0;
if (isLetter(c1)) {
endIndex += cc1;
while ((endIndex < input.length()) && isLetter(c0 = input.codePointAt(endIndex))) {
endIndex += charCount(c0);
}
}
finished = fragmentConsumer.test(addUtf8Bytes(input, startIndex, endIndex, utf8Bytes));
} else if (isNumeric(c0)) {
// 3) `\p{N}{1,3}` - numbers, such as `4`, `235` or `3½`
endIndex += cc0;
if (isNumeric(c1)) {
endIndex += cc1;
if ((endIndex < input.length()) && isNumeric(c0 = input.codePointAt(endIndex))) {
endIndex += charCount(c0);
}
}
finished = fragmentConsumer.test(addUtf8Bytes(input, startIndex, endIndex, utf8Bytes));
} else if (isNotWhitespaceOrLetterOrNumeric(c0) || ((c0 == ' ') && isNotWhitespaceOrLetterOrNumeric(c1))) {
// 4) ` ?[^\s\p{L}\p{N}]++[\r\n]*` - punctuation, such as `,`, ` .`, `"`
endIndex += cc0;
if ((endIndex < input.length()) && isNotWhitespaceOrLetterOrNumeric(c1)) {
endIndex += cc1;
while ((endIndex < input.length()) && isNotWhitespaceOrLetterOrNumeric(c0 = input.codePointAt(endIndex))) {
endIndex += charCount(c0);
}
}
while ((endIndex < input.length()) && isNewline(input.codePointAt(endIndex))) {
endIndex++;
}
finished = fragmentConsumer.test(addUtf8Bytes(input, startIndex, endIndex, utf8Bytes));
} else {
// 5) `\s*[\r\n]+` - line endings such as `\r\n \r\n`
// 6) `\s+(?!\S)` - whitespaces such as ` ` or ` `
// 7) `\s+` - unmatched remaining spaces, such as ` `
assert isWhitespace(c0) : "Invalid character: " + Arrays.toString(toChars(c0));
int lastNewLineIndex = isNewline(c0) ? endIndex : -1;
endIndex += cc0;
if (isWhitespace(c1)) {
lastNewLineIndex = isNewline(c1) ? endIndex : lastNewLineIndex;
endIndex += cc1;
while (endIndex < input.length() && isWhitespace(c0 = input.codePointAt(endIndex))) {
lastNewLineIndex = isNewline(c0) ? endIndex : lastNewLineIndex;
endIndex += charCount(c0);
}
}

if (lastNewLineIndex > -1) {
int finalEndIndex = endIndex;
endIndex = lastNewLineIndex + 1;
if (endIndex < finalEndIndex) {
assert startIndex < endIndex;
finished = fragmentConsumer.test(addUtf8Bytes(input, startIndex, endIndex, utf8Bytes));
startIndex = endIndex;
endIndex = finalEndIndex;
}
}
if (!finished) {
if (lastNewLineIndex + 1 < endIndex && !isWhitespace(c0)) {
endIndex--;
}
if (startIndex < endIndex) {
finished = fragmentConsumer.test(addUtf8Bytes(input, startIndex, endIndex, utf8Bytes));
}
}
}
}
}


static boolean isShortContraction(int ch) {
return SDTM.indexOf(ch) >= 0;
}

static boolean isLongContraction(int ch1, int ch2) {
if (((ch1 == 'l') && (ch2 == 'l'))
|| ((ch1 == 'v') && (ch2 == 'e'))
|| ((ch1 == 'r') && (ch2 == 'e'))) {
return true;
} else {
int lch1 = toUpperCase(ch1);
int lch2 = toUpperCase(ch2);
return ((lch1 == 'L') && (lch2 == 'L'))
|| ((lch1 == 'V') && (lch2 == 'E'))
|| ((lch1 == 'R') && (lch2 == 'E'));
}
}

public static boolean isValidUTF8(String input) {
return UTF_8.newEncoder().canEncode(input);
}

public static boolean isLetter(int ch) {
if (ch < 0xaa) {
return ((ch >= 'a') && (ch <= 'z'))
|| ((ch >= 'A') && (ch <= 'Z'));
} else if (ch <= 0x323af) {
switch (getType(ch)) {
case UPPERCASE_LETTER:
case LOWERCASE_LETTER:
case TITLECASE_LETTER:
case MODIFIER_LETTER:
case OTHER_LETTER:
return true;
}
}
return false;
}

public static boolean isNumeric(int ch) {
if (ch < 0xb2) {
return (ch >= '0') && (ch <= '9');
} else if (ch <= 0x1fbf9) {
switch (getType(ch)) {
case DECIMAL_DIGIT_NUMBER:
case LETTER_NUMBER:
case OTHER_NUMBER:
return true;
}
}
return false;
}

static boolean isLetterOrNumeric(int ch) {
if (ch < 0xaa) {
return ((ch >= 'a') && (ch <= 'z'))
|| ((ch >= 'A') && (ch <= 'Z'))
|| ((ch >= '0') && (ch <= '9'));
} else if (ch <= 0x323af) {
switch (getType(ch)) {
case UPPERCASE_LETTER:
case LOWERCASE_LETTER:
case TITLECASE_LETTER:
case MODIFIER_LETTER:
case OTHER_LETTER:
case DECIMAL_DIGIT_NUMBER:
case LETTER_NUMBER:
case OTHER_NUMBER:
return true;
}
}
return false;
}

public static boolean isWhitespace(int ch) {
if (ch <= '\r') {
return SIMPLE_WHITESPACES.indexOf(ch) >= 0;
} else if (ch < '\u0085') {
return ch == ' ';
} else {
return (ch == '\u0085')
|| (ch == '\u00A0')
|| ((ch >= '\u1680') && (ch <= '\u3000') && (binarySearch(REMAINING_WHITESPACES, ch) >= 0));
}
}

static boolean isNewline(int ch) {
return (ch == '\r')
|| (ch == '\n');
}

public static boolean isNotWhitespaceOrLetterOrNumeric(int ch) {
if (ch < '0') {
return ch >= 0 && ch != ' ' && (ch > '\r' || ch < '\t');
} else {
return !isLetterOrNumeric(ch) && !isWhitespace(ch);
}
}

public static boolean isNotNewlineOrLetterOrNumeric(int ch) {
if (ch < '0') {
return ch >= 0 && (ch == ' ' || !isNewline(ch));
} else {
return !isLetterOrNumeric(ch);
}
}

static List<Byte> addUtf8Bytes(String input, int start, int end, List<Byte> dst) {
dst.clear();
for (int i = start; i < end; i++) {
int cp = input.codePointAt(i);
if (cp < 0x80) {
dst.add((byte) cp);
} else if (cp < 0x800) {
dst.add((byte) (0xc0 | (cp >> 0x6)));
dst.add((byte) (0x80 | (cp & 0x3f)));
} else if (cp < MIN_SUPPLEMENTARY_CODE_POINT) {
dst.add((byte) (0xe0 | (cp >> 0xc)));
dst.add((byte) (0x80 | ((cp >> 0x6) & 0x3f)));
dst.add((byte) (0x80 | (cp & 0x3f)));
} else {
assert cp < (MAX_CODE_POINT + 1) : "Invalid code point: " + cp;
dst.add((byte) (0xf0 | (cp >> 0x12)));
dst.add((byte) (0x80 | ((cp >> 0xc) & 0x3f)));
dst.add((byte) (0x80 | ((cp >> 0x6) & 0x3f)));
dst.add((byte) (0x80 | (cp & 0x3f)));
i++;
}
}
return dst;
}
}
32 changes: 25 additions & 7 deletions lib/src/main/java/com/knuddels/jtokkit/EncodingFactory.java
Original file line number Diff line number Diff line change
Expand Up @@ -103,13 +103,10 @@ public static Encoding p50kEdit() {
* @return an {@link Encoding} instance for the cl100k_base encoding
*/
public static Encoding cl100kBase() {
return fromPredefinedParameters(
"cl100k_base",
"'(?:[sdmt]|ll|ve|re)|[^\r\n\\p{L}\\p{N}]?+\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]++[\r\n]*|\\s*[\r\n]|\\s+(?!\\S)|\\s+",
"/com/knuddels/jtokkit/cl100k_base.tiktoken",
SPECIAL_TOKENS_CL100K_BASE,
true
);
// "'(?:[sdmt]|ll|ve|re)|[^\r\n\\p{L}\\p{N}]?+\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]++[\r\n]*|\\s*[\r\n]|\\s+(?!\\S)|\\s+"
Map<byte[], Integer> mergeableRanks = loadMergeableRanks("/com/knuddels/jtokkit/cl100k_base.tiktoken");
GptBytePairEncodingParams params = new GptBytePairEncodingParams("cl100k_base", null, mergeableRanks, SPECIAL_TOKENS_CL100K_BASE);
return new Cl100kGptBytePairEncoding(params);
}

/**
Expand Down Expand Up @@ -176,4 +173,25 @@ public static Map<byte[], Integer> loadMergeableRanks(String fileName) {
throw new IllegalStateException("Could not load " + fileName + " from resources", e);
}
}

private static class Cl100kGptBytePairEncoding extends GptBytePairEncoding {
public Cl100kGptBytePairEncoding(GptBytePairEncodingParams params) {
super(params);
}

@Override
int encodeOrdinaryInternal(String text, int maxTokenCount, boolean keepEncodings, List<Integer> out) {
int[] tokenCount = {0};
ArrayList<Integer> ranks = new ArrayList<>();
Cl100kParser.split(text, utf8BytesList -> {
byte[] utf8Bytes = new byte[utf8BytesList.size()];
for (int i = 0; i < utf8BytesList.size(); i++) {
utf8Bytes[i] = utf8BytesList.get(i);
}
tokenCount[0] += encoder.addTokensAndGetCount(maxTokenCount, keepEncodings, utf8Bytes, out, ranks);
return tokenCount[0] >= maxTokenCount;
});
return tokenCount[0];
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,9 @@
*/
class GptBytePairEncoding implements Encoding {

final TokenEncoder encoder;
private final String name;
private final Pattern pattern;
private final TokenEncoder encoder;
private final SpecialEncoder specialEncoder;

/**
Expand Down Expand Up @@ -78,7 +78,7 @@ private EncodingResult encodeOrdinaryInternal(String text, int maxTokenCount, bo
// Make sure we didn't break the multibyte character
for (int tokensToRemove = 0; tokensToRemove <= out.size(); tokensToRemove++) {
int size = out.size() - tokensToRemove;
List<Integer> tokens = new ArrayList<>(size);
ArrayList<Integer> tokens = new ArrayList<>(size);
for (int i = 0; i < size; i++) {
tokens.add(out.get(i));
}
Expand Down
Loading