Skip to content

Commit 79a6720

Browse files
authored
Fixes #1024, Add back string tensor support (#1040)
Change-Id: I7b3f326669dec6739d24131d7507984390817226
1 parent 08fdab4 commit 79a6720

File tree

3 files changed

+65
-18
lines changed

3 files changed

+65
-18
lines changed

tensorflow/tensorflow-engine/src/main/java/ai/djl/tensorflow/engine/TfNDArray.java

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
import ai.djl.util.Preconditions;
2727
import java.nio.Buffer;
2828
import java.nio.ByteBuffer;
29+
import java.nio.charset.StandardCharsets;
2930
import java.util.ArrayList;
3031
import java.util.Arrays;
3132
import java.util.List;
@@ -171,9 +172,7 @@ public NDArray stopGradient() {
171172
/** {@inheritDoc} */
172173
@Override
173174
public String[] toStringArray() {
174-
// TODO: Parse String Array from bytes[]
175-
throw new UnsupportedOperationException(
176-
"TensorFlow does not supporting printing String NDArray");
175+
return new String[] {JavacppUtils.getString(getHandle(), StandardCharsets.UTF_8)};
177176
}
178177

179178
/** {@inheritDoc} */

tensorflow/tensorflow-engine/src/main/java/ai/djl/tensorflow/engine/TfNDManager.java

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,6 @@ public NDArray create(Shape shape, DataType dataType) {
6262
}
6363

6464
/** {@inheritDoc} */
65-
@SuppressWarnings({"unchecked", "try"})
6665
@Override
6766
public TfNDArray create(Buffer data, Shape shape, DataType dataType) {
6867
int size = data.remaining();
@@ -99,6 +98,13 @@ public TfNDArray create(Buffer data, Shape shape, DataType dataType) {
9998
return new TfNDArray(this, handle);
10099
}
101100

101+
/** {@inheritDoc} */
102+
@Override
103+
public NDArray create(String data) {
104+
TFE_TensorHandle handle = JavacppUtils.createStringTensor(data);
105+
return new TfNDArray(this, handle);
106+
}
107+
102108
/** {@inheritDoc} */
103109
@Override
104110
public final Engine getEngine() {

tensorflow/tensorflow-engine/src/main/java/ai/djl/tensorflow/engine/javacpp/JavacppUtils.java

Lines changed: 56 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -23,10 +23,12 @@
2323
import com.google.protobuf.InvalidProtocolBufferException;
2424
import java.nio.ByteBuffer;
2525
import java.nio.ByteOrder;
26+
import java.nio.charset.Charset;
2627
import java.nio.charset.StandardCharsets;
2728
import java.util.regex.Matcher;
2829
import java.util.regex.Pattern;
2930
import org.bytedeco.javacpp.BytePointer;
31+
import org.bytedeco.javacpp.Loader;
3032
import org.bytedeco.javacpp.Pointer;
3133
import org.bytedeco.javacpp.PointerPointer;
3234
import org.bytedeco.javacpp.PointerScope;
@@ -45,6 +47,7 @@
4547
import org.tensorflow.internal.c_api.TF_Session;
4648
import org.tensorflow.internal.c_api.TF_SessionOptions;
4749
import org.tensorflow.internal.c_api.TF_Status;
50+
import org.tensorflow.internal.c_api.TF_TString;
4851
import org.tensorflow.internal.c_api.TF_Tensor;
4952
import org.tensorflow.internal.c_api.global.tensorflow;
5053
import org.tensorflow.proto.framework.ConfigProto;
@@ -62,7 +65,7 @@ private JavacppUtils() {}
6265
@SuppressWarnings({"unchecked", "try"})
6366
public static SavedModelBundle loadSavedModelBundle(
6467
String exportDir, String[] tags, ConfigProto config, RunOptions runOptions) {
65-
try (PointerScope scope = new PointerScope()) {
68+
try (PointerScope ignored = new PointerScope()) {
6669
TF_Status status = TF_Status.newStatus();
6770

6871
// allocate parameters for TF_LoadSessionFromSavedModel
@@ -141,7 +144,7 @@ public static TF_Tensor[] runSession(
141144
int numInputs = inputTensorHandles.length;
142145
int numOutputs = outputOpHandles.length;
143146
int numTargets = targetOpHandles.length;
144-
try (PointerScope scope = new PointerScope()) {
147+
try (PointerScope ignored = new PointerScope()) {
145148
// TODO: check with sig-jvm if TF_Output here is freed
146149
TF_Output inputs = new TF_Output(numInputs);
147150
PointerPointer<TF_Tensor> inputValues = new PointerPointer<>(numInputs);
@@ -199,7 +202,7 @@ public static TF_Tensor[] runSession(
199202
@SuppressWarnings({"unchecked", "try"})
200203
public static TFE_Context createEagerSession(
201204
boolean async, int devicePlacementPolicy, ConfigProto config) {
202-
try (PointerScope scope = new PointerScope()) {
205+
try (PointerScope ignored = new PointerScope()) {
203206
TFE_ContextOptions opts = TFE_ContextOptions.newContextOptions();
204207
TF_Status status = TF_Status.newStatus();
205208
if (config != null) {
@@ -218,7 +221,7 @@ public static TFE_Context createEagerSession(
218221

219222
@SuppressWarnings({"unchecked", "try"})
220223
public static Device getDevice(TFE_TensorHandle handle) {
221-
try (PointerScope scope = new PointerScope()) {
224+
try (PointerScope ignored = new PointerScope()) {
222225
TF_Status status = TF_Status.newStatus();
223226
BytePointer pointer = tensorflow.TFE_TensorHandleDeviceName(handle, status);
224227
String device = new String(pointer.getStringBytes(), StandardCharsets.UTF_8);
@@ -232,7 +235,7 @@ public static DataType getDataType(TFE_TensorHandle handle) {
232235

233236
@SuppressWarnings({"unchecked", "try"})
234237
public static Shape getShape(TFE_TensorHandle handle) {
235-
try (PointerScope scope = new PointerScope()) {
238+
try (PointerScope ignored = new PointerScope()) {
236239
TF_Status status = TF_Status.newStatus();
237240
int numDims = tensorflow.TFE_TensorHandleNumDims(handle, status);
238241
status.throwExceptionIfNotOK();
@@ -258,7 +261,7 @@ public static TF_Tensor createEmptyTFTensor(Shape shape, DataType dataType) {
258261

259262
@SuppressWarnings({"unchecked", "try"})
260263
public static TFE_TensorHandle createEmptyTFETensor(Shape shape, DataType dataType) {
261-
try (PointerScope scope = new PointerScope()) {
264+
try (PointerScope ignored = new PointerScope()) {
262265
TF_Tensor tensor = createEmptyTFTensor(shape, dataType);
263266
TF_Status status = TF_Status.newStatus();
264267
TFE_TensorHandle handle = AbstractTFE_TensorHandle.newTensor(tensor, status);
@@ -267,13 +270,36 @@ public static TFE_TensorHandle createEmptyTFETensor(Shape shape, DataType dataTy
267270
}
268271
}
269272

273+
@SuppressWarnings({"unchecked", "try"})
274+
public static TFE_TensorHandle createStringTensor(String src) {
275+
int dType = TfDataType.toTf(DataType.STRING);
276+
long[] dims = {};
277+
long numBytes = Loader.sizeof(TF_TString.class);
278+
try (PointerScope ignored = new PointerScope()) {
279+
TF_Tensor tensor = AbstractTF_Tensor.allocateTensor(dType, dims, numBytes);
280+
Pointer pointer = tensorflow.TF_TensorData(tensor).capacity(numBytes);
281+
TF_TString data = new TF_TString(pointer).capacity(pointer.position() + 1);
282+
byte[] buf = src.getBytes(StandardCharsets.UTF_8);
283+
tensorflow.TF_TString_Copy(data, new BytePointer(buf), buf.length);
284+
TF_Status status = TF_Status.newStatus();
285+
TFE_TensorHandle handle = AbstractTFE_TensorHandle.newTensor(tensor, status);
286+
status.throwExceptionIfNotOK();
287+
return handle.retainReference();
288+
}
289+
}
290+
270291
@SuppressWarnings({"unchecked", "try"})
271292
public static TFE_TensorHandle createTFETensorFromByteBuffer(
272293
ByteBuffer buf, Shape shape, DataType dataType) {
273294
int dType = TfDataType.toTf(dataType);
274295
long[] dims = shape.getShape();
275-
long numBytes = shape.size() * dataType.getNumOfBytes();
276-
try (PointerScope scope = new PointerScope()) {
296+
long numBytes;
297+
if (dataType == DataType.STRING) {
298+
numBytes = buf.remaining() + 1;
299+
} else {
300+
numBytes = shape.size() * dataType.getNumOfBytes();
301+
}
302+
try (PointerScope ignored = new PointerScope()) {
277303
TF_Tensor tensor = AbstractTF_Tensor.allocateTensor(dType, dims, numBytes);
278304
// get data pointer in native engine
279305
Pointer pointer = tensorflow.TF_TensorData(tensor).capacity(numBytes);
@@ -287,7 +313,7 @@ public static TFE_TensorHandle createTFETensorFromByteBuffer(
287313

288314
@SuppressWarnings({"unchecked", "try"})
289315
public static TF_Tensor resolveTFETensor(TFE_TensorHandle handle) {
290-
try (PointerScope scope = new PointerScope()) {
316+
try (PointerScope ignored = new PointerScope()) {
291317
TF_Status status = TF_Status.newStatus();
292318
TF_Tensor tensor = tensorflow.TFE_TensorHandleResolve(handle, status).withDeallocator();
293319
status.throwExceptionIfNotOK();
@@ -297,17 +323,34 @@ public static TF_Tensor resolveTFETensor(TFE_TensorHandle handle) {
297323

298324
@SuppressWarnings({"unchecked", "try"})
299325
public static TFE_TensorHandle createTFETensor(TF_Tensor handle) {
300-
try (PointerScope scope = new PointerScope()) {
326+
try (PointerScope ignored = new PointerScope()) {
301327
TF_Status status = TF_Status.newStatus();
302328
TFE_TensorHandle tensor = AbstractTFE_TensorHandle.newTensor(handle, status);
303329
status.throwExceptionIfNotOK();
304330
return tensor.retainReference();
305331
}
306332
}
307333

334+
@SuppressWarnings({"unchecked", "try"})
335+
public static String getString(TFE_TensorHandle handle, Charset charset) {
336+
try (PointerScope ignored = new PointerScope()) {
337+
// convert to TF_Tensor
338+
TF_Status status = TF_Status.newStatus();
339+
TF_Tensor tensor = tensorflow.TFE_TensorHandleResolve(handle, status).withDeallocator();
340+
status.throwExceptionIfNotOK();
341+
342+
Pointer pointer =
343+
tensorflow.TF_TensorData(tensor).capacity(tensorflow.TF_TensorByteSize(tensor));
344+
345+
TF_TString data = new TF_TString(pointer).capacity(pointer.position() + 1);
346+
BytePointer bp = tensorflow.TF_TString_GetDataPointer(data);
347+
return bp.getString(charset);
348+
}
349+
}
350+
308351
@SuppressWarnings({"unchecked", "try"})
309352
public static ByteBuffer getByteBuffer(TFE_TensorHandle handle) {
310-
try (PointerScope scope = new PointerScope()) {
353+
try (PointerScope ignored = new PointerScope()) {
311354
// convert to TF_Tensor
312355
TF_Status status = TF_Status.newStatus();
313356
TF_Tensor tensor = tensorflow.TFE_TensorHandleResolve(handle, status).withDeallocator();
@@ -328,7 +371,7 @@ public static ByteBuffer getByteBuffer(TFE_TensorHandle handle) {
328371
@SuppressWarnings({"unchecked", "try"})
329372
public static TFE_TensorHandle toDevice(
330373
TFE_TensorHandle handle, TFE_Context eagerSessionHandle, Device device) {
331-
try (PointerScope scope = new PointerScope()) {
374+
try (PointerScope ignored = new PointerScope()) {
332375
String deviceName = toTfDevice(device);
333376
TF_Status status = TF_Status.newStatus();
334377
TFE_TensorHandle newHandle =
@@ -372,8 +415,7 @@ public static String toTfDevice(Device device) {
372415
} else if (device.getDeviceType().equals(Device.Type.GPU)) {
373416
return "/device:GPU:" + device.getDeviceId();
374417
} else {
375-
throw new EngineException(
376-
"Unknown device type to TensorFlow Engine: " + device.toString());
418+
throw new EngineException("Unknown device type to TensorFlow Engine: " + device);
377419
}
378420
}
379421
}

0 commit comments

Comments
 (0)