Skip to content

Commit 85d9e85

Browse files
authored
[huggingface] Adds CrossEncoderTranslator (#2817)
1 parent 23e07cf commit 85d9e85

File tree

6 files changed

+644
-0
lines changed

6 files changed

+644
-0
lines changed
Lines changed: 115 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,115 @@
1+
/*
2+
* Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
5+
* with the License. A copy of the License is located at
6+
*
7+
* http://aws.amazon.com/apache2.0/
8+
*
9+
* or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
10+
* OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
11+
* and limitations under the License.
12+
*/
13+
package ai.djl.modality.nlp.translator;
14+
15+
import ai.djl.modality.Input;
16+
import ai.djl.modality.Output;
17+
import ai.djl.ndarray.BytesSupplier;
18+
import ai.djl.ndarray.NDList;
19+
import ai.djl.translate.Batchifier;
20+
import ai.djl.translate.NoBatchifyTranslator;
21+
import ai.djl.translate.TranslateException;
22+
import ai.djl.translate.Translator;
23+
import ai.djl.translate.TranslatorContext;
24+
import ai.djl.util.JsonUtils;
25+
import ai.djl.util.PairList;
26+
import ai.djl.util.StringPair;
27+
28+
import com.google.gson.JsonElement;
29+
import com.google.gson.JsonParseException;
30+
31+
/** A {@link Translator} that can handle generic cross encoder {@link Input} and {@link Output}. */
32+
public class CrossEncoderServingTranslator implements NoBatchifyTranslator<Input, Output> {
33+
34+
private Translator<StringPair, float[]> translator;
35+
private Translator<StringPair[], float[][]> batchTranslator;
36+
37+
/**
38+
* Constructs a {@code CrossEncoderServingTranslator} instance.
39+
*
40+
* @param translator a {@code Translator} processes question answering input
41+
*/
42+
public CrossEncoderServingTranslator(Translator<StringPair, float[]> translator) {
43+
this.translator = translator;
44+
this.batchTranslator = translator.toBatchTranslator();
45+
}
46+
47+
/** {@inheritDoc} */
48+
@Override
49+
public void prepare(TranslatorContext ctx) throws Exception {
50+
translator.prepare(ctx);
51+
batchTranslator.prepare(ctx);
52+
}
53+
54+
/** {@inheritDoc} */
55+
@Override
56+
public NDList processInput(TranslatorContext ctx, Input input) throws Exception {
57+
PairList<String, BytesSupplier> content = input.getContent();
58+
if (content.isEmpty()) {
59+
throw new TranslateException("Input data is empty.");
60+
}
61+
62+
String contentType = input.getProperty("Content-Type", null);
63+
StringPair pair;
64+
if ("application/json".equals(contentType)) {
65+
String json = input.getData().getAsString();
66+
try {
67+
JsonElement element = JsonUtils.GSON.fromJson(json, JsonElement.class);
68+
if (element.isJsonArray()) {
69+
ctx.setAttachment("batch", Boolean.TRUE);
70+
StringPair[] inputs = JsonUtils.GSON.fromJson(json, StringPair[].class);
71+
return batchTranslator.processInput(ctx, inputs);
72+
}
73+
74+
pair = JsonUtils.GSON.fromJson(json, StringPair.class);
75+
if (pair.getKey() == null || pair.getValue() == null) {
76+
throw new TranslateException("Missing key or value in json.");
77+
}
78+
} catch (JsonParseException e) {
79+
throw new TranslateException("Input is not a valid json.", e);
80+
}
81+
} else {
82+
String key = input.getAsString("key");
83+
String value = input.getAsString("value");
84+
if (key == null || value == null) {
85+
throw new TranslateException("Missing key or value in input.");
86+
}
87+
pair = new StringPair(key, value);
88+
}
89+
90+
NDList ret = translator.processInput(ctx, pair);
91+
Batchifier batchifier = translator.getBatchifier();
92+
if (batchifier != null) {
93+
NDList[] batch = {ret};
94+
return batchifier.batchify(batch);
95+
}
96+
return ret;
97+
}
98+
99+
/** {@inheritDoc} */
100+
@Override
101+
public Output processOutput(TranslatorContext ctx, NDList list) throws Exception {
102+
Output output = new Output();
103+
output.addProperty("Content-Type", "application/json");
104+
if (ctx.getAttachment("batch") != null) {
105+
output.add(BytesSupplier.wrapAsJson(batchTranslator.processOutput(ctx, list)));
106+
} else {
107+
Batchifier batchifier = translator.getBatchifier();
108+
if (batchifier != null) {
109+
list = batchifier.unbatchify(list)[0];
110+
}
111+
output.add(BytesSupplier.wrapAsJson(translator.processOutput(ctx, list)));
112+
}
113+
return output;
114+
}
115+
}
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
/*
2+
* Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
5+
* with the License. A copy of the License is located at
6+
*
7+
* http://aws.amazon.com/apache2.0/
8+
*
9+
* or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
10+
* OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
11+
* and limitations under the License.
12+
*/
13+
package ai.djl.util;
14+
15+
/** A class containing the string key-value pair. */
16+
public class StringPair extends Pair<String, String> {
17+
18+
/**
19+
* Constructs a {@code Pair} instance with key and value.
20+
*
21+
* @param key the key
22+
* @param value the value
23+
*/
24+
public StringPair(String key, String value) {
25+
super(key, value);
26+
}
27+
}
Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
/*
2+
* Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
5+
* with the License. A copy of the License is located at
6+
*
7+
* http://aws.amazon.com/apache2.0/
8+
*
9+
* or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
10+
* OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
11+
* and limitations under the License.
12+
*/
13+
package ai.djl.huggingface.translator;
14+
15+
import ai.djl.huggingface.tokenizers.Encoding;
16+
import ai.djl.huggingface.tokenizers.HuggingFaceTokenizer;
17+
import ai.djl.ndarray.NDArray;
18+
import ai.djl.ndarray.NDList;
19+
import ai.djl.ndarray.NDManager;
20+
import ai.djl.translate.Batchifier;
21+
import ai.djl.translate.NoBatchifyTranslator;
22+
import ai.djl.translate.TranslateException;
23+
import ai.djl.translate.TranslatorContext;
24+
import ai.djl.util.PairList;
25+
import ai.djl.util.StringPair;
26+
27+
import java.util.Arrays;
28+
29+
/** The translator for Huggingface cross encoder model. */
30+
public class CrossEncoderBatchTranslator implements NoBatchifyTranslator<StringPair[], float[][]> {
31+
32+
private HuggingFaceTokenizer tokenizer;
33+
private boolean includeTokenTypes;
34+
private Batchifier batchifier;
35+
36+
CrossEncoderBatchTranslator(
37+
HuggingFaceTokenizer tokenizer, boolean includeTokenTypes, Batchifier batchifier) {
38+
this.tokenizer = tokenizer;
39+
this.includeTokenTypes = includeTokenTypes;
40+
this.batchifier = batchifier;
41+
}
42+
43+
/** {@inheritDoc} */
44+
@Override
45+
public NDList processInput(TranslatorContext ctx, StringPair[] inputs)
46+
throws TranslateException {
47+
NDManager manager = ctx.getNDManager();
48+
PairList<String, String> list = new PairList<>(Arrays.asList(inputs));
49+
Encoding[] encodings = tokenizer.batchEncode(list);
50+
NDList[] batch = new NDList[encodings.length];
51+
for (int i = 0; i < encodings.length; ++i) {
52+
batch[i] = encodings[i].toNDList(manager, includeTokenTypes);
53+
}
54+
return batchifier.batchify(batch);
55+
}
56+
57+
/** {@inheritDoc} */
58+
@Override
59+
public float[][] processOutput(TranslatorContext ctx, NDList list) {
60+
NDList[] batch = batchifier.unbatchify(list);
61+
float[][] ret = new float[batch.length][];
62+
for (int i = 0; i < batch.length; ++i) {
63+
NDArray logits = list.get(0);
64+
NDArray result = logits.getNDArrayInternal().sigmoid();
65+
ret[i] = result.toFloatArray();
66+
}
67+
return ret;
68+
}
69+
}
Lines changed: 149 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,149 @@
1+
/*
2+
* Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
5+
* with the License. A copy of the License is located at
6+
*
7+
* http://aws.amazon.com/apache2.0/
8+
*
9+
* or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
10+
* OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
11+
* and limitations under the License.
12+
*/
13+
package ai.djl.huggingface.translator;
14+
15+
import ai.djl.huggingface.tokenizers.Encoding;
16+
import ai.djl.huggingface.tokenizers.HuggingFaceTokenizer;
17+
import ai.djl.ndarray.NDArray;
18+
import ai.djl.ndarray.NDList;
19+
import ai.djl.translate.ArgumentsUtil;
20+
import ai.djl.translate.Batchifier;
21+
import ai.djl.translate.Translator;
22+
import ai.djl.translate.TranslatorContext;
23+
import ai.djl.util.StringPair;
24+
25+
import java.io.IOException;
26+
import java.util.Map;
27+
28+
/** The translator for Huggingface cross encoder model. */
29+
public class CrossEncoderTranslator implements Translator<StringPair, float[]> {
30+
31+
private HuggingFaceTokenizer tokenizer;
32+
private boolean includeTokenTypes;
33+
private Batchifier batchifier;
34+
35+
CrossEncoderTranslator(
36+
HuggingFaceTokenizer tokenizer, boolean includeTokenTypes, Batchifier batchifier) {
37+
this.tokenizer = tokenizer;
38+
this.includeTokenTypes = includeTokenTypes;
39+
this.batchifier = batchifier;
40+
}
41+
42+
/** {@inheritDoc} */
43+
@Override
44+
public Batchifier getBatchifier() {
45+
return batchifier;
46+
}
47+
48+
/** {@inheritDoc} */
49+
@Override
50+
public NDList processInput(TranslatorContext ctx, StringPair input) {
51+
Encoding encoding = tokenizer.encode(input.getKey(), input.getValue());
52+
ctx.setAttachment("encoding", encoding);
53+
return encoding.toNDList(ctx.getNDManager(), includeTokenTypes);
54+
}
55+
56+
/** {@inheritDoc} */
57+
@Override
58+
public float[] processOutput(TranslatorContext ctx, NDList list) {
59+
NDArray logits = list.get(0);
60+
NDArray result = logits.getNDArrayInternal().sigmoid();
61+
return result.toFloatArray();
62+
}
63+
64+
/** {@inheritDoc} */
65+
@Override
66+
public CrossEncoderBatchTranslator toBatchTranslator(Batchifier batchifier) {
67+
tokenizer.enableBatch();
68+
return new CrossEncoderBatchTranslator(tokenizer, includeTokenTypes, batchifier);
69+
}
70+
71+
/**
72+
* Creates a builder to build a {@code CrossEncoderTranslator}.
73+
*
74+
* @param tokenizer the tokenizer
75+
* @return a new builder
76+
*/
77+
public static Builder builder(HuggingFaceTokenizer tokenizer) {
78+
return new Builder(tokenizer);
79+
}
80+
81+
/**
82+
* Creates a builder to build a {@code CrossEncoderTranslator}.
83+
*
84+
* @param tokenizer the tokenizer
85+
* @param arguments the models' arguments
86+
* @return a new builder
87+
*/
88+
public static Builder builder(HuggingFaceTokenizer tokenizer, Map<String, ?> arguments) {
89+
Builder builder = builder(tokenizer);
90+
builder.configure(arguments);
91+
92+
return builder;
93+
}
94+
95+
/** The builder for question answering translator. */
96+
public static final class Builder {
97+
98+
private HuggingFaceTokenizer tokenizer;
99+
private boolean includeTokenTypes;
100+
private Batchifier batchifier = Batchifier.STACK;
101+
102+
Builder(HuggingFaceTokenizer tokenizer) {
103+
this.tokenizer = tokenizer;
104+
}
105+
106+
/**
107+
* Sets if include token types for the {@link Translator}.
108+
*
109+
* @param includeTokenTypes true to include token types
110+
* @return this builder
111+
*/
112+
public Builder optIncludeTokenTypes(boolean includeTokenTypes) {
113+
this.includeTokenTypes = includeTokenTypes;
114+
return this;
115+
}
116+
117+
/**
118+
* Sets the {@link Batchifier} for the {@link Translator}.
119+
*
120+
* @param batchifier true to include token types
121+
* @return this builder
122+
*/
123+
public Builder optBatchifier(Batchifier batchifier) {
124+
this.batchifier = batchifier;
125+
return this;
126+
}
127+
128+
/**
129+
* Configures the builder with the model arguments.
130+
*
131+
* @param arguments the model arguments
132+
*/
133+
public void configure(Map<String, ?> arguments) {
134+
optIncludeTokenTypes(ArgumentsUtil.booleanValue(arguments, "includeTokenTypes"));
135+
String batchifierStr = ArgumentsUtil.stringValue(arguments, "batchifier", "stack");
136+
optBatchifier(Batchifier.fromString(batchifierStr));
137+
}
138+
139+
/**
140+
* Builds the translator.
141+
*
142+
* @return the new translator
143+
* @throws IOException if I/O error occurs
144+
*/
145+
public CrossEncoderTranslator build() throws IOException {
146+
return new CrossEncoderTranslator(tokenizer, includeTokenTypes, batchifier);
147+
}
148+
}
149+
}

0 commit comments

Comments
 (0)