Skip to content

Commit c9647b9

Browse files
author
Qing Lan
committed
update onnxruntime along with String tensor
Change-Id: Ie50ee7cd63864a14dc46ae7be0566f1f0b319931
1 parent 48cf663 commit c9647b9

File tree

6 files changed

+50
-4
lines changed

6 files changed

+50
-4
lines changed

api/src/main/java/ai/djl/ndarray/BaseNDManager.java

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,12 @@ public NDArray create(String data) {
4949
throw new UnsupportedOperationException("Not supported!");
5050
}
5151

52+
/** {@inheritDoc} */
53+
@Override
54+
public NDArray create(String[] data) {
55+
throw new UnsupportedOperationException("Not supported!");
56+
}
57+
5258
/** {@inheritDoc} */
5359
@Override
5460
public NDArray create(Shape shape, DataType dataType) {

api/src/main/java/ai/djl/ndarray/NDManager.java

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -235,14 +235,21 @@ default NDArray create(boolean data) {
235235
}
236236

237237
/**
238-
* Creates and initializes a scalar {@link NDArray}. NDArray of String DataType only supports
239-
* scalar.
238+
* Creates and initializes a scalar {@link NDArray}.
240239
*
241240
* @param data the String data that needs to be set
242241
* @return a new instance of {@link NDArray}
243242
*/
244243
NDArray create(String data);
245244

245+
/**
246+
* Creates and initializes 1D {@link NDArray}.
247+
*
248+
* @param data the String data that needs to be set
249+
* @return a new instance of {@link NDArray}
250+
*/
251+
NDArray create(String[] data);
252+
246253
/**
247254
* Creates and initializes a 1D {@link NDArray}.
248255
*

gradle.properties

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ pytorch_version=1.7.1
1313
tensorflow_version=2.3.1
1414
tflite_version=2.4.1
1515
dlr_version=1.6.0
16-
onnxruntime_version=1.5.2
16+
onnxruntime_version=1.7.0
1717
paddlepaddle_version=2.0.0
1818
sentencepiece_version=0.1.92
1919
fasttext_version=0.9.2

onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/engine/OrtNDManager.java

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,26 @@ public OrtNDArray create(Buffer data, Shape shape, DataType dataType) {
6363
}
6464
}
6565

66+
/** {@inheritDoc} */
67+
@Override
68+
public NDArray create(String data) {
69+
return create(new String[] {data});
70+
}
71+
72+
/** {@inheritDoc} */
73+
@Override
74+
public NDArray create(String[] data) {
75+
return create(data, new Shape(data.length));
76+
}
77+
78+
public NDArray create(String[] data, Shape shape) {
79+
try {
80+
return new OrtNDArray(this, OrtUtils.toTensor(env, data, shape));
81+
} catch (OrtException e) {
82+
throw new EngineException(e);
83+
}
84+
}
85+
6686
/** {@inheritDoc} */
6787
@Override
6888
public NDArray zeros(Shape shape, DataType dataType) {

onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/engine/OrtUtils.java

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,12 @@ public static OnnxTensor toTensor(
6363
}
6464
}
6565

66+
public static OnnxTensor toTensor(OrtEnvironment env, String[] inputs, Shape shape)
67+
throws OrtException {
68+
long[] sh = shape.getShape();
69+
return OnnxTensor.createTensor(env, inputs, sh);
70+
}
71+
6672
public static NDArray toNDArray(NDManager manager, OnnxTensor tensor) {
6773
if (manager instanceof OrtNDManager) {
6874
return ((OrtNDManager) manager).create(tensor);

tensorflow/tensorflow-engine/src/main/java/ai/djl/tensorflow/engine/TfNDManager.java

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -149,12 +149,19 @@ public NDArray create(float data) {
149149
/** {@inheritDoc} */
150150
@Override
151151
public NDArray create(String data) {
152-
// create scalar tensor with float
153152
try (Tensor<TString> tensor = TString.scalarOf(data)) {
154153
return new TfNDArray(this, tensor);
155154
}
156155
}
157156

157+
/** {@inheritDoc} */
158+
@Override
159+
public NDArray create(String[] data) {
160+
try (Tensor<TString> tensor = TString.vectorOf(data)) {
161+
return new TfNDArray(this, tensor);
162+
}
163+
}
164+
158165
/** {@inheritDoc} */
159166
@Override
160167
public NDArray create(Shape shape, DataType dataType) {

0 commit comments

Comments
 (0)