diff --git a/api/src/main/java/ai/djl/ndarray/NDArray.java b/api/src/main/java/ai/djl/ndarray/NDArray.java index ebb3148b285..57cabbd2da5 100644 --- a/api/src/main/java/ai/djl/ndarray/NDArray.java +++ b/api/src/main/java/ai/djl/ndarray/NDArray.java @@ -4592,6 +4592,57 @@ default NDArray oneHot(int depth) { return oneHot(depth, 1f, 0f, DataType.FLOAT32); } + /** + * Returns a one-hot {@code NDArray}. + * + *
Examples + * + *
+ * jshell> NDArray array = manager.create(new int[] {1, 0, 2, 0});
+ * jshell> array.oneHot(3);
+ * ND: (4, 3) cpu() float32
+ * [[0., 1., 0.],
+ * [1., 0., 0.],
+ * [0., 0., 1.],
+ * [1., 0., 0.],
+ * ]
+ * jshell> NDArray array = manager.create(new int[][] {{1, 0}, {1, 0}, {2, 0}});
+ * jshell> array.oneHot(3);
+ * ND: (3, 2, 3) cpu() float32
+ * [[[0., 1., 0.],
+ * [1., 0., 0.],
+ * ],
+ * [[0., 1., 0.],
+ * [1., 0., 0.],
+ * ],
+ * [[0., 0., 1.],
+ * [1., 0., 0.],
+ * ],
+ * ]
+ *
+ *
+ * @param depth Depth of the one hot dimension.
+ * @param dataType dataType of the output.
+ * @return one-hot encoding of this {@code NDArray}
+ * @see Classification-problems
+ */
+ default NDArray oneHot(int depth, DataType dataType) {
+ return oneHot(depth, 0f, 1f, dataType);
+ }
+
/**
* Returns a one-hot {@code NDArray}.
*
diff --git a/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArray.java b/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArray.java
index 3e0ede826a5..246b9e57fee 100644
--- a/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArray.java
+++ b/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArray.java
@@ -1408,6 +1408,18 @@ public NDArray norm(int order, int[] axes, boolean keepDims) {
return JniUtils.norm(this, order, axes, keepDims);
}
+ /** {@inheritDoc} */
+ @Override
+ public NDArray oneHot(int depth) {
+ return JniUtils.oneHot(this, depth, DataType.FLOAT32);
+ }
+
+ /** {@inheritDoc} */
+ @Override
+ public NDArray oneHot(int depth, DataType dataType) {
+ return JniUtils.oneHot(this, depth, dataType);
+ }
+
/** {@inheritDoc} */
@Override
public NDArray oneHot(int depth, float onValue, float offValue, DataType dataType) {
diff --git a/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/JniUtils.java b/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/JniUtils.java
index c98f4e60d75..64a8965f736 100644
--- a/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/JniUtils.java
+++ b/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/JniUtils.java
@@ -720,6 +720,14 @@ public static PtNDArray cumSum(PtNDArray ndArray, long dim) {
ndArray.getManager(), PyTorchLibrary.LIB.torchCumSum(ndArray.getHandle(), dim));
}
+ public static PtNDArray oneHot(PtNDArray ndArray, int depth, DataType dataType) {
+ return new PtNDArray(
+ ndArray.getManager(),
+ PyTorchLibrary.LIB.torchNNOneHot(
+ ndArray.toType(DataType.INT64, false).getHandle(), depth))
+ .toType(dataType, false);
+ }
+
public static NDList split(PtNDArray ndArray, long size, long axis) {
long[] ndPtrs = PyTorchLibrary.LIB.torchSplit(ndArray.getHandle(), size, axis);
NDList list = new NDList();
diff --git a/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/PyTorchLibrary.java b/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/PyTorchLibrary.java
index 2145ada3df9..0bb7dbafccf 100644
--- a/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/PyTorchLibrary.java
+++ b/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/PyTorchLibrary.java
@@ -448,6 +448,8 @@ native long torchNNMaxPool(
native long torchNNLpPool(
long inputHandle, double normType, long[] kernelSize, long[] stride, boolean ceilMode);
+ native long torchNNOneHot(long inputHandle, int depth);
+
native boolean torchRequiresGrad(long inputHandle);
native String torchGradFnName(long inputHandle);
diff --git a/pytorch/pytorch-native/src/main/native/ai_djl_pytorch_jni_PyTorchLibrary_inference.cc b/pytorch/pytorch-native/src/main/native/ai_djl_pytorch_jni_PyTorchLibrary_inference.cc
index 643bd20e5f8..66f2e171ba4 100644
--- a/pytorch/pytorch-native/src/main/native/ai_djl_pytorch_jni_PyTorchLibrary_inference.cc
+++ b/pytorch/pytorch-native/src/main/native/ai_djl_pytorch_jni_PyTorchLibrary_inference.cc
@@ -99,8 +99,8 @@ JNIEXPORT void JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_moduleWrite(
API_BEGIN()
auto* module_ptr = reinterpret_cast