Skip to content

Commit d8e7e1d

Browse files
authored
Fixes #999, handle UTF-16 surrogate characters properly. (#1003)
Change-Id: I19e77cf5a8282bea901434041806eb102549ec0f
1 parent b0fe73a commit d8e7e1d

File tree

2 files changed

+45
-4
lines changed

2 files changed

+45
-4
lines changed

api/src/main/native/djl/utils.h

Lines changed: 30 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -29,9 +29,21 @@ inline std::string GetStringFromJString(JNIEnv* env, jstring jstr) {
2929
if (jstr == nullptr) {
3030
return std::string();
3131
}
32-
const char* c_str = env->GetStringUTFChars(jstr, JNI_FALSE);
33-
std::string str = std::string(c_str);
34-
env->ReleaseStringUTFChars(jstr, c_str);
32+
33+
// TODO: cache reflection to improve performance
34+
const jclass string_class = env->GetObjectClass(jstr);
35+
const jmethodID getbytes_method = env->GetMethodID(string_class, "getBytes", "(Ljava/lang/String;)[B");
36+
37+
const jstring charset = env->NewStringUTF("UTF-8");
38+
const jbyteArray jbytes = (jbyteArray) env->CallObjectMethod(jstr, getbytes_method, charset);
39+
env->DeleteLocalRef(charset);
40+
41+
const jsize length = env->GetArrayLength(jbytes);
42+
jbyte* c_str = env->GetByteArrayElements(jbytes, NULL);
43+
std::string str = std::string(reinterpret_cast<const char *>(c_str), length);
44+
45+
env->ReleaseByteArrayElements(jbytes, c_str, RELEASE_MODE);
46+
env->DeleteLocalRef(jbytes);
3547
return str;
3648
}
3749

@@ -100,9 +112,23 @@ inline std::vector<std::string> GetVecFromJStringArray(JNIEnv* env, jobjectArray
100112
// String[]
101113
/**
 * Builds a Java String[] from a vector of UTF-8 encoded std::string values.
 *
 * Each element is copied into a byte[] and decoded on the Java side with
 * `new String(byte[], "UTF-8")`. This is used instead of NewStringUTF because
 * NewStringUTF expects JNI "modified UTF-8", which cannot represent
 * supplementary characters (UTF-16 surrogate pairs) as standard 4-byte
 * sequences.
 *
 * @param env JNI environment of the calling thread
 * @param vec strings to convert; the bytes of each element are passed through unchanged
 * @return a local reference to the new String[], or nullptr when a JNI call
 *         failed (a Java exception is then pending in `env`)
 */
inline jobjectArray GetStringArrayFromVec(JNIEnv* env, const std::vector <std::string> &vec) {
  // TODO: cache reflection to improve performance
  // FindClass takes the slash-delimited class name, not the "L...;" field descriptor.
  const jclass string_class = env->FindClass("java/lang/String");
  if (string_class == nullptr) {
    return nullptr;
  }

  jstring charset = nullptr;
  jobjectArray array = nullptr;

  // Every step is guarded: no JNI call other than DeleteLocalRef is made once
  // an earlier one has failed and left an exception pending.
  const jmethodID ctor = env->GetMethodID(string_class, "<init>", "([BLjava/lang/String;)V");
  if (ctor != nullptr) {
    charset = env->NewStringUTF("UTF-8");
  }
  if (charset != nullptr) {
    array = env->NewObjectArray(static_cast<jsize>(vec.size()), string_class, nullptr);
  }

  bool ok = (array != nullptr);
  for (std::size_t i = 0; ok && i < vec.size(); ++i) {
    const std::string& item = vec[i];
    const jsize len = static_cast<jsize>(item.size());

    const jbyteArray jbytes = env->NewByteArray(len);
    if (jbytes == nullptr) {
      ok = false;
      break;
    }
    env->SetByteArrayRegion(jbytes, 0, len, reinterpret_cast<const jbyte*>(item.data()));

    const jobject jstr = env->NewObject(string_class, ctor, jbytes, charset);
    env->DeleteLocalRef(jbytes);
    if (jstr == nullptr) {
      ok = false;
      break;
    }

    env->SetObjectArrayElement(array, static_cast<jsize>(i), jstr);
    // The array now holds the string; drop the local ref so large vectors
    // do not exhaust the local reference table.
    env->DeleteLocalRef(jstr);
  }

  if (charset != nullptr) {
    env->DeleteLocalRef(charset);
  }
  env->DeleteLocalRef(string_class);

  if (!ok) {
    if (array != nullptr) {
      env->DeleteLocalRef(array);
    }
    return nullptr;
  }
  return array;
}
108134

extensions/sentencepiece/src/test/java/ai/djl/sentencepiece/SpTokenizerTest.java

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,21 @@ public void testTokenize() throws IOException {
5353
}
5454
}
5555

56+
@Test
57+
@SuppressWarnings("AvoidEscapedUnicodeCharacters")
58+
public void testUtf16Tokenize() throws IOException {
59+
if (System.getProperty("os.name").startsWith("Win")) {
60+
throw new SkipException("Skip windows test.");
61+
}
62+
Path modelPath = Paths.get("build/test/models/sententpiece_test_model.model");
63+
try (SpTokenizer tokenizer = new SpTokenizer(modelPath)) {
64+
String original = "\uD83D\uDC4B\uD83D\uDC4B";
65+
List<String> tokens = tokenizer.tokenize(original);
66+
List<String> expected = Arrays.asList("▁", "\uD83D\uDC4B\uD83D\uDC4B");
67+
Assert.assertEquals(tokens, expected);
68+
}
69+
}
70+
5671
@Test
5772
public void testEncodeDecode() throws IOException {
5873
if (System.getProperty("os.name").startsWith("Win")) {

0 commit comments

Comments
 (0)