Skip to content

Commit 8479de8

Browse files
author
Soha Agarwal
committed
Fix preference of tokenizer_config.json and remove doLowerCase from TokenizerConfig
1 parent d1a7eb7 commit 8479de8

File tree

3 files changed

+66
-37
lines changed

3 files changed

+66
-37
lines changed

extensions/tokenizers/src/main/java/ai/djl/huggingface/tokenizers/HuggingFaceTokenizer.java

Lines changed: 21 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -100,31 +100,36 @@ private HuggingFaceTokenizer(
100100
modelMaxLength = 512;
101101
}
102102
if (config != null) {
103-
applyConfig(config);
103+
applyConfig(config, options);
104104
}
105105
updateTruncationAndPadding(padInfo);
106106
}
107107

108-
private void applyConfig(TokenizerConfig config) {
109-
this.modelMaxLength = config.getModelMaxLength();
110-
if (config.hasExplicitDoLowerCase() && config.isDoLowerCase()) {
111-
this.doLowerCase = Locale.getDefault();
108+
private void applyConfig(TokenizerConfig config, Map<String, String> options) {
109+
if (options != null && !options.containsKey("modelMaxLength")) {
110+
this.modelMaxLength = config.getModelMaxLength();
112111
}
113112
this.cleanupTokenizationSpaces = config.isCleanUpTokenizationSpaces();
114-
if (Stream.of(
115-
config.getBosToken(),
116-
config.getClsToken(),
117-
config.getEosToken(),
118-
config.getSepToken(),
119-
config.getUnkToken(),
120-
config.getPadToken())
121-
.anyMatch(token -> token != null && !token.isEmpty())) {
122-
this.addSpecialTokens = true;
113+
if (options != null && !options.containsKey("addSpecialTokens")) {
114+
115+
this.addSpecialTokens =
116+
Stream.of(
117+
config.getBosToken(),
118+
config.getClsToken(),
119+
config.getEosToken(),
120+
config.getSepToken(),
121+
config.getUnkToken(),
122+
config.getPadToken())
123+
.anyMatch(token -> token != null && !token.isEmpty());
123124
}
124-
if (config.hasExplicitStripAccents()) {
125+
if (options != null
126+
&& !options.containsKey("stripAccents")
127+
&& config.hasExplicitStripAccents()) {
125128
this.stripAccents = config.isStripAccents();
126129
}
127-
if (config.hasExplicitAddPrefixSpace()) {
130+
if (options != null
131+
&& !options.containsKey("addPrefixSpace")
132+
&& config.hasExplicitAddPrefixSpace()) {
128133
this.addPrefixSpace = config.isAddPrefixSpace();
129134
}
130135
}

extensions/tokenizers/src/main/java/ai/djl/huggingface/tokenizers/TokenizerConfig.java

Lines changed: 0 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -41,9 +41,6 @@ public class TokenizerConfig {
4141
@SerializedName("model_max_length")
4242
private Integer modelMaxLength;
4343

44-
@SerializedName("do_lower_case")
45-
private Boolean doLowerCase;
46-
4744
@SerializedName("strip_accents")
4845
private Boolean stripAccents;
4946

@@ -103,15 +100,6 @@ public int getModelMaxLength() {
103100
return modelMaxLength;
104101
}
105102

106-
/**
107-
* Is do lower case boolean.
108-
*
109-
* @return the boolean
110-
*/
111-
public boolean isDoLowerCase() {
112-
return Boolean.TRUE.equals(doLowerCase);
113-
}
114-
115103
/**
116104
* Is strip accents boolean.
117105
*
@@ -202,15 +190,6 @@ public String getTokenizerClass() {
202190
return tokenizerClass;
203191
}
204192

205-
/**
206-
* Has explicit do lower case boolean.
207-
*
208-
* @return the boolean
209-
*/
210-
public boolean hasExplicitDoLowerCase() {
211-
return doLowerCase != null;
212-
}
213-
214193
/**
215194
* Has explicit strip accents boolean.
216195
*

extensions/tokenizers/src/test/java/ai/djl/huggingface/tokenizers/HuggingFaceTokenizerTest.java

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -645,4 +645,49 @@ public void testConfigParameters() throws IOException {
645645
Assert.assertEquals(0, ids[0], "First token should have id 0 (<s>)");
646646
}
647647
}
648+
649+
@Test
650+
public void testPreferenceWhenBothOptionsAndConfigSet() throws IOException {
651+
try (HuggingFaceTokenizer tokenizer =
652+
HuggingFaceTokenizer.builder()
653+
.optMaxLength(48)
654+
.optAddSpecialTokens(false)
655+
.optTokenizerPath(
656+
Paths.get("src/test/resources/fake-tokenizer-with-null-padding/"))
657+
.optTokenizerConfigPath(
658+
"src/test/resources/fake-tokenizer-with-null-padding/tokenizer_config.json")
659+
.build()) {
660+
String input = "Hello World";
661+
Encoding encoding = tokenizer.encode(input); // with special tokens
662+
String[] tokens = encoding.getTokens();
663+
664+
// Verify no special tokens are added: optAddSpecialTokens(false) from the builder
// should take precedence over tokenizer_config.json
665+
Assert.assertEquals(tokens[0], "▁Hello"); // bos_token/cls_token
666+
Assert.assertEquals(
667+
tokens[tokens.length - 1],
668+
"▁World"); // Last actual token without special tokens
669+
670+
String[] testInputs = {
671+
"Hello World", // Basic text
672+
"Hello World", // Multiple spaces
673+
String.join(" ", Collections.nCopies(1000, "hello")), // Long text
674+
"résumé café", // Accented characters
675+
"Hello\nWorld", // Newlines
676+
"Hello World" // Extra spaces
677+
};
678+
679+
for (String testInput : testInputs) {
680+
encoding = tokenizer.encode(testInput);
681+
682+
// Verify encoding basics
683+
Assert.assertNotNull(encoding);
684+
Assert.assertNotNull(encoding.getIds());
685+
Assert.assertNotNull(encoding.getTokens());
686+
687+
// Verify the builder's optMaxLength(48) limit is enforced, not the config's model_max_length
688+
Assert.assertTrue(
689+
encoding.getIds().length <= 48, "Encoding length should not exceed 48");
690+
}
691+
}
692+
}
648693
}

0 commit comments

Comments
 (0)