Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import ai.djl.util.Preconditions;
import java.nio.Buffer;
import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
Expand Down Expand Up @@ -171,9 +172,7 @@ public NDArray stopGradient() {
/** {@inheritDoc} */
@Override
public String[] toStringArray() {
// TODO: Parse String Array from bytes[]
throw new UnsupportedOperationException(
"TensorFlow does not supporting printing String NDArray");
return new String[] {JavacppUtils.getString(getHandle(), StandardCharsets.UTF_8)};
}

/** {@inheritDoc} */
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,6 @@ public NDArray create(Shape shape, DataType dataType) {
}

/** {@inheritDoc} */
@SuppressWarnings({"unchecked", "try"})
@Override
public TfNDArray create(Buffer data, Shape shape, DataType dataType) {
int size = data.remaining();
Expand Down Expand Up @@ -99,6 +98,13 @@ public TfNDArray create(Buffer data, Shape shape, DataType dataType) {
return new TfNDArray(this, handle);
}

/** {@inheritDoc} */
@Override
public NDArray create(String data) {
TFE_TensorHandle handle = JavacppUtils.createStringTensor(data);
return new TfNDArray(this, handle);
}

/** {@inheritDoc} */
@Override
public final Engine getEngine() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,12 @@
import com.google.protobuf.InvalidProtocolBufferException;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.charset.Charset;
import java.nio.charset.StandardCharsets;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import org.bytedeco.javacpp.BytePointer;
import org.bytedeco.javacpp.Loader;
import org.bytedeco.javacpp.Pointer;
import org.bytedeco.javacpp.PointerPointer;
import org.bytedeco.javacpp.PointerScope;
Expand All @@ -45,6 +47,7 @@
import org.tensorflow.internal.c_api.TF_Session;
import org.tensorflow.internal.c_api.TF_SessionOptions;
import org.tensorflow.internal.c_api.TF_Status;
import org.tensorflow.internal.c_api.TF_TString;
import org.tensorflow.internal.c_api.TF_Tensor;
import org.tensorflow.internal.c_api.global.tensorflow;
import org.tensorflow.proto.framework.ConfigProto;
Expand All @@ -62,7 +65,7 @@ private JavacppUtils() {}
@SuppressWarnings({"unchecked", "try"})
public static SavedModelBundle loadSavedModelBundle(
String exportDir, String[] tags, ConfigProto config, RunOptions runOptions) {
try (PointerScope scope = new PointerScope()) {
try (PointerScope ignored = new PointerScope()) {
TF_Status status = TF_Status.newStatus();

// allocate parameters for TF_LoadSessionFromSavedModel
Expand Down Expand Up @@ -141,7 +144,7 @@ public static TF_Tensor[] runSession(
int numInputs = inputTensorHandles.length;
int numOutputs = outputOpHandles.length;
int numTargets = targetOpHandles.length;
try (PointerScope scope = new PointerScope()) {
try (PointerScope ignored = new PointerScope()) {
// TODO: check with sig-jvm if TF_Output here is freed
TF_Output inputs = new TF_Output(numInputs);
PointerPointer<TF_Tensor> inputValues = new PointerPointer<>(numInputs);
Expand Down Expand Up @@ -199,7 +202,7 @@ public static TF_Tensor[] runSession(
@SuppressWarnings({"unchecked", "try"})
public static TFE_Context createEagerSession(
boolean async, int devicePlacementPolicy, ConfigProto config) {
try (PointerScope scope = new PointerScope()) {
try (PointerScope ignored = new PointerScope()) {
TFE_ContextOptions opts = TFE_ContextOptions.newContextOptions();
TF_Status status = TF_Status.newStatus();
if (config != null) {
Expand All @@ -218,7 +221,7 @@ public static TFE_Context createEagerSession(

@SuppressWarnings({"unchecked", "try"})
public static Device getDevice(TFE_TensorHandle handle) {
try (PointerScope scope = new PointerScope()) {
try (PointerScope ignored = new PointerScope()) {
TF_Status status = TF_Status.newStatus();
BytePointer pointer = tensorflow.TFE_TensorHandleDeviceName(handle, status);
String device = new String(pointer.getStringBytes(), StandardCharsets.UTF_8);
Expand All @@ -232,7 +235,7 @@ public static DataType getDataType(TFE_TensorHandle handle) {

@SuppressWarnings({"unchecked", "try"})
public static Shape getShape(TFE_TensorHandle handle) {
try (PointerScope scope = new PointerScope()) {
try (PointerScope ignored = new PointerScope()) {
TF_Status status = TF_Status.newStatus();
int numDims = tensorflow.TFE_TensorHandleNumDims(handle, status);
status.throwExceptionIfNotOK();
Expand All @@ -258,7 +261,7 @@ public static TF_Tensor createEmptyTFTensor(Shape shape, DataType dataType) {

@SuppressWarnings({"unchecked", "try"})
public static TFE_TensorHandle createEmptyTFETensor(Shape shape, DataType dataType) {
try (PointerScope scope = new PointerScope()) {
try (PointerScope ignored = new PointerScope()) {
TF_Tensor tensor = createEmptyTFTensor(shape, dataType);
TF_Status status = TF_Status.newStatus();
TFE_TensorHandle handle = AbstractTFE_TensorHandle.newTensor(tensor, status);
Expand All @@ -267,13 +270,36 @@ public static TFE_TensorHandle createEmptyTFETensor(Shape shape, DataType dataTy
}
}

@SuppressWarnings({"unchecked", "try"})
public static TFE_TensorHandle createStringTensor(String src) {
int dType = TfDataType.toTf(DataType.STRING);
long[] dims = {};
long numBytes = Loader.sizeof(TF_TString.class);
try (PointerScope ignored = new PointerScope()) {
TF_Tensor tensor = AbstractTF_Tensor.allocateTensor(dType, dims, numBytes);
Pointer pointer = tensorflow.TF_TensorData(tensor).capacity(numBytes);
TF_TString data = new TF_TString(pointer).capacity(pointer.position() + 1);
byte[] buf = src.getBytes(StandardCharsets.UTF_8);
tensorflow.TF_TString_Copy(data, new BytePointer(buf), buf.length);
TF_Status status = TF_Status.newStatus();
TFE_TensorHandle handle = AbstractTFE_TensorHandle.newTensor(tensor, status);
status.throwExceptionIfNotOK();
return handle.retainReference();
}
}

@SuppressWarnings({"unchecked", "try"})
public static TFE_TensorHandle createTFETensorFromByteBuffer(
ByteBuffer buf, Shape shape, DataType dataType) {
int dType = TfDataType.toTf(dataType);
long[] dims = shape.getShape();
long numBytes = shape.size() * dataType.getNumOfBytes();
try (PointerScope scope = new PointerScope()) {
long numBytes;
if (dataType == DataType.STRING) {
numBytes = buf.remaining() + 1;
} else {
numBytes = shape.size() * dataType.getNumOfBytes();
}
try (PointerScope ignored = new PointerScope()) {
TF_Tensor tensor = AbstractTF_Tensor.allocateTensor(dType, dims, numBytes);
// get data pointer in native engine
Pointer pointer = tensorflow.TF_TensorData(tensor).capacity(numBytes);
Expand All @@ -287,7 +313,7 @@ public static TFE_TensorHandle createTFETensorFromByteBuffer(

@SuppressWarnings({"unchecked", "try"})
public static TF_Tensor resolveTFETensor(TFE_TensorHandle handle) {
try (PointerScope scope = new PointerScope()) {
try (PointerScope ignored = new PointerScope()) {
TF_Status status = TF_Status.newStatus();
TF_Tensor tensor = tensorflow.TFE_TensorHandleResolve(handle, status).withDeallocator();
status.throwExceptionIfNotOK();
Expand All @@ -297,17 +323,34 @@ public static TF_Tensor resolveTFETensor(TFE_TensorHandle handle) {

@SuppressWarnings({"unchecked", "try"})
public static TFE_TensorHandle createTFETensor(TF_Tensor handle) {
try (PointerScope scope = new PointerScope()) {
try (PointerScope ignored = new PointerScope()) {
TF_Status status = TF_Status.newStatus();
TFE_TensorHandle tensor = AbstractTFE_TensorHandle.newTensor(handle, status);
status.throwExceptionIfNotOK();
return tensor.retainReference();
}
}

@SuppressWarnings({"unchecked", "try"})
public static String getString(TFE_TensorHandle handle, Charset charset) {
try (PointerScope ignored = new PointerScope()) {
// convert to TF_Tensor
TF_Status status = TF_Status.newStatus();
TF_Tensor tensor = tensorflow.TFE_TensorHandleResolve(handle, status).withDeallocator();
status.throwExceptionIfNotOK();

Pointer pointer =
tensorflow.TF_TensorData(tensor).capacity(tensorflow.TF_TensorByteSize(tensor));

TF_TString data = new TF_TString(pointer).capacity(pointer.position() + 1);
BytePointer bp = tensorflow.TF_TString_GetDataPointer(data);
return bp.getString(charset);
}
}

@SuppressWarnings({"unchecked", "try"})
public static ByteBuffer getByteBuffer(TFE_TensorHandle handle) {
try (PointerScope scope = new PointerScope()) {
try (PointerScope ignored = new PointerScope()) {
// convert to TF_Tensor
TF_Status status = TF_Status.newStatus();
TF_Tensor tensor = tensorflow.TFE_TensorHandleResolve(handle, status).withDeallocator();
Expand All @@ -328,7 +371,7 @@ public static ByteBuffer getByteBuffer(TFE_TensorHandle handle) {
@SuppressWarnings({"unchecked", "try"})
public static TFE_TensorHandle toDevice(
TFE_TensorHandle handle, TFE_Context eagerSessionHandle, Device device) {
try (PointerScope scope = new PointerScope()) {
try (PointerScope ignored = new PointerScope()) {
String deviceName = toTfDevice(device);
TF_Status status = TF_Status.newStatus();
TFE_TensorHandle newHandle =
Expand Down Expand Up @@ -372,8 +415,7 @@ public static String toTfDevice(Device device) {
} else if (device.getDeviceType().equals(Device.Type.GPU)) {
return "/device:GPU:" + device.getDeviceId();
} else {
throw new EngineException(
"Unknown device type to TensorFlow Engine: " + device.toString());
throw new EngineException("Unknown device type to TensorFlow Engine: " + device);
}
}
}