Skip to content

Commit 2158e99

Browse files
authored
Minor fixes on duplicated code (#736)
* remove methods that already defined in the NDArrayAdapter Change-Id: I01cc03a7f5b427bf31c6b3fd8d2136f2a27fe93b * refactor toString Change-Id: Iea22b16e1daa9f759b55c1a8b8b85536482e551a * remove sparse NDArray Change-Id: Icb44096519775f54cb32cc768c14f49e33dc7ea5 * fix test Change-Id: Icef580ed77e7bba22864ce44577de3cba51e3e41
1 parent 43e5891 commit 2158e99

File tree

8 files changed

+27
-115
lines changed

8 files changed

+27
-115
lines changed

api/src/main/java/ai/djl/BaseModel.java

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -267,6 +267,22 @@ public Path getModelPath() {
267267
return modelDir;
268268
}
269269

270+
/** {@inheritDoc} */
271+
@Override
272+
public String toString() {
273+
StringBuilder sb = new StringBuilder(200);
274+
sb.append("Model (\n\tName: ").append(modelName);
275+
if (modelDir != null) {
276+
sb.append("\n\tModel location: ").append(modelDir.toAbsolutePath());
277+
}
278+
sb.append("\n\tData Type: ").append(dataType);
279+
for (Map.Entry<String, String> entry : properties.entrySet()) {
280+
sb.append("\n\t").append(entry.getKey()).append(": ").append(entry.getValue());
281+
}
282+
sb.append("\n)");
283+
return sb.toString();
284+
}
285+
270286
/** {@inheritDoc} */
271287
@SuppressWarnings("deprecation")
272288
@Override

dlr/dlr-engine/src/main/java/ai/djl/dlr/engine/DlrNDArray.java

Lines changed: 0 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -116,18 +116,6 @@ public void detach() {
116116
manager = DlrNDManager.getSystemManager();
117117
}
118118

119-
/** {@inheritDoc} */
120-
@Override
121-
public boolean hasGradient() {
122-
return false;
123-
}
124-
125-
/** {@inheritDoc} */
126-
@Override
127-
public NDArray stopGradient() {
128-
throw new UnsupportedOperationException("Not supported for DLR");
129-
}
130-
131119
/** {@inheritDoc} */
132120
@Override
133121
public ByteBuffer toByteBuffer() {

integration/src/main/java/ai/djl/integration/tests/ndarray/NDArrayCreationOpTest.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ public void testCreateCSRMatrix() {
7070
long[] indptr = {0, 2, 2, 3};
7171
long[] indices = {0, 2, 1};
7272
NDArray nd = manager.createCSR(buf, indptr, indices, new Shape(3, 4));
73-
float[] array = nd.toFloatArray();
73+
float[] array = nd.toDense().toFloatArray();
7474
Assert.assertEquals(array[0], expected[0]);
7575
Assert.assertEquals(array[2], expected[1]);
7676
Assert.assertEquals(array[9], expected[2]);
@@ -85,7 +85,7 @@ public void testCreateRowSparseMatrix() {
8585
FloatBuffer buf = FloatBuffer.wrap(expected);
8686
long[] indices = {0, 1, 3};
8787
NDArray nd = manager.createRowSparse(buf, new Shape(3, 2), indices, new Shape(4, 2));
88-
float[] array = nd.toFloatArray();
88+
float[] array = nd.toDense().toFloatArray();
8989
Assert.assertEquals(array[0], expected[0]);
9090
Assert.assertEquals(array[1], expected[1]);
9191
Assert.assertEquals(array[2], expected[2]);

mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxModel.java

Lines changed: 0 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -200,20 +200,4 @@ private void loadParameters(Path paramFile, Map<String, ?> options)
200200
dataType = paramNDlist.head().getDataType();
201201
logger.debug("MXNet Model {} ({}) loaded successfully.", paramFile, dataType);
202202
}
203-
204-
/** {@inheritDoc} */
205-
@Override
206-
public String toString() {
207-
StringBuilder sb = new StringBuilder(200);
208-
sb.append("Model (\n\tName: ").append(modelName);
209-
if (modelDir != null) {
210-
sb.append("\n\tModel location: ").append(modelDir.toAbsolutePath());
211-
}
212-
sb.append("\n\tData Type: ").append(dataType);
213-
for (Map.Entry<String, String> entry : properties.entrySet()) {
214-
sb.append("\n\t").append(entry.getKey()).append(": ").append(entry.getValue());
215-
}
216-
sb.append("\n)");
217-
return sb.toString();
218-
}
219203
}

mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxNDArray.java

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -290,6 +290,9 @@ public String[] toStringArray() {
290290
/** {@inheritDoc} */
291291
@Override
292292
public ByteBuffer toByteBuffer() {
293+
if (getSparseFormat() != SparseFormat.DENSE) {
294+
throw new IllegalStateException("Require Dense NDArray, actual " + getSparseFormat());
295+
}
293296
Shape sh = getShape();
294297
DataType dType = getDataType();
295298
long product = sh.size();

mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxNDManager.java

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -81,8 +81,8 @@ public MxNDArray create(Pointer handle) {
8181
* @param fmt the sparse format to use
8282
* @return the created array
8383
*/
84-
public MxSparseNDArray create(Pointer handle, SparseFormat fmt) {
85-
return new MxSparseNDArray(this, handle, fmt);
84+
public MxNDArray create(Pointer handle, SparseFormat fmt) {
85+
return new MxNDArray(this, handle, fmt);
8686
}
8787

8888
/** {@inheritDoc} */
@@ -97,7 +97,7 @@ public MxNDArray create(Shape shape, DataType dataType) {
9797

9898
/** {@inheritDoc} */
9999
@Override
100-
public MxSparseNDArray createCSR(Buffer data, long[] indptr, long[] indices, Shape shape) {
100+
public MxNDArray createCSR(Buffer data, long[] indptr, long[] indices, Shape shape) {
101101
SparseFormat fmt = SparseFormat.CSR;
102102
DataType dataType = DataType.fromBuffer(data);
103103
MxNDArray indptrNd = create(new Shape(indptr.length), DataType.INT64);
@@ -113,7 +113,7 @@ public MxSparseNDArray createCSR(Buffer data, long[] indptr, long[] indices, Sha
113113
new DataType[] {indptrNd.getDataType(), indicesNd.getDataType()},
114114
new Shape[] {indptrNd.getShape(), indicesNd.getShape()},
115115
false);
116-
MxSparseNDArray sparse = create(handle, fmt);
116+
MxNDArray sparse = create(handle, fmt);
117117
MxNDArray dataNd = create(new Shape(data.remaining()), dataType);
118118
dataNd.set(data);
119119
JnaUtils.ndArraySyncCopyFromNdArray(sparse, dataNd, -1);
@@ -124,8 +124,7 @@ public MxSparseNDArray createCSR(Buffer data, long[] indptr, long[] indices, Sha
124124

125125
/** {@inheritDoc} */
126126
@Override
127-
public MxSparseNDArray createRowSparse(
128-
Buffer data, Shape dataShape, long[] indices, Shape shape) {
127+
public MxNDArray createRowSparse(Buffer data, Shape dataShape, long[] indices, Shape shape) {
129128
SparseFormat fmt = SparseFormat.ROW_SPARSE;
130129
DataType dataType = DataType.fromBuffer(data);
131130
MxNDArray indicesNd = create(new Shape(indices.length), DataType.INT64);
@@ -139,7 +138,7 @@ public MxSparseNDArray createRowSparse(
139138
new DataType[] {indicesNd.getDataType()},
140139
new Shape[] {indicesNd.getShape()},
141140
false);
142-
MxSparseNDArray sparse = create(handle, fmt);
141+
MxNDArray sparse = create(handle, fmt);
143142
MxNDArray dataNd = create(dataShape, dataType);
144143
dataNd.set(data);
145144
JnaUtils.ndArraySyncCopyFromNdArray(sparse, dataNd, -1);

mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxSparseNDArray.java

Lines changed: 0 additions & 62 deletions
This file was deleted.

paddlepaddle/paddlepaddle-engine/src/main/java/ai/djl/paddlepaddle/engine/PpModel.java

Lines changed: 0 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -96,22 +96,6 @@ private String[] findModelFile(Path dir) {
9696
return null;
9797
}
9898

99-
/** {@inheritDoc} */
100-
@Override
101-
public String toString() {
102-
StringBuilder sb = new StringBuilder(200);
103-
sb.append("Model (\n\tName: ").append(modelName);
104-
if (modelDir != null) {
105-
sb.append("\n\tModel location: ").append(modelDir.toAbsolutePath());
106-
}
107-
sb.append("\n\tData Type: ").append(dataType);
108-
for (Map.Entry<String, String> entry : properties.entrySet()) {
109-
sb.append("\n\t").append(entry.getKey()).append(": ").append(entry.getValue());
110-
}
111-
sb.append("\n)");
112-
return sb.toString();
113-
}
114-
11599
/** {@inheritDoc} */
116100
@Override
117101
public <I, O> Predictor<I, O> newPredictor(Translator<I, O> translator) {

0 commit comments

Comments
 (0)