Skip to content

Commit f40d14d

Browse files
andrewfayrespiyushghai
authored andcommitted
Java Inference api and SSD example (apache#12830)
* New Java inference API and SSD example * Adding license to java files and fixing SSD example * Fixing SSD example to point to ObjectDetector instead of ImageClassifier * Make scripts for object detector independent to os and hw cpu/gpu * Added API Docs to Java Inference API. Small fixes for PR * Cosmetic updates for API DOCS requested during PR * Attempt to fix the CI Javafx compiler issue * Migrate from Javafx to apache commons for Pair implementation * Removing javafx from pom file * Fixes to appease the ScalaStyle deity * Minor fix in SSD script and Readme * Added ObjectDetectorOutput which is a POJO for Object Detector to simplify the return type * Removing Apache Commons Immutable Pair * Adding license to new file * Minor style fixes * minor style fix * Updating to be in scala style and not explicitly declare some unnecessary variables
1 parent ae10a40 commit f40d14d

8 files changed

Lines changed: 586 additions & 3 deletions

File tree

scala-package/examples/scripts/infer/objectdetector/run_ssd_example.sh

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,21 @@
1717
# specific language governing permissions and limitations
1818
# under the License.
1919

20+
hw_type=cpu
21+
if [[ $1 = gpu ]]
22+
then
23+
hw_type=gpu
24+
fi
25+
26+
platform=linux-x86_64
27+
28+
if [[ $OSTYPE = [darwin]* ]]
29+
then
30+
platform=osx-x86_64
31+
fi
2032

2133
MXNET_ROOT=$(cd "$(dirname $0)/../../../../../"; pwd)
22-
CLASS_PATH=$MXNET_ROOT/scala-package/assembly/osx-x86_64-cpu/target/*:$MXNET_ROOT/scala-package/examples/target/*:$MXNET_ROOT/scala-package/examples/target/classes/lib/*:$MXNET_ROOT/scala-package/infer/target/*
34+
CLASS_PATH=$MXNET_ROOT/scala-package/assembly/$platform-$hw_type/target/*:$MXNET_ROOT/scala-package/examples/target/*:$MXNET_ROOT/scala-package/examples/target/classes/lib/*:$MXNET_ROOT/scala-package/infer/target/*
2335

2436
# model dir and prefix
2537
MODEL_DIR=$1
Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
#!/bin/bash
2+
3+
# Licensed to the Apache Software Foundation (ASF) under one
4+
# or more contributor license agreements. See the NOTICE file
5+
# distributed with this work for additional information
6+
# regarding copyright ownership. The ASF licenses this file
7+
# to you under the Apache License, Version 2.0 (the
8+
# "License"); you may not use this file except in compliance
9+
# with the License. You may obtain a copy of the License at
10+
#
11+
# http://www.apache.org/licenses/LICENSE-2.0
12+
#
13+
# Unless required by applicable law or agreed to in writing,
14+
# software distributed under the License is distributed on an
15+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
16+
# KIND, either express or implied. See the License for the
17+
# specific language governing permissions and limitations
18+
# under the License.
19+
20+
hw_type=cpu
21+
if [[ $4 = gpu ]]
22+
then
23+
hw_type=gpu
24+
fi
25+
26+
platform=linux-x86_64
27+
28+
if [[ $OSTYPE = [darwin]* ]]
29+
then
30+
platform=osx-x86_64
31+
fi
32+
33+
MXNET_ROOT=$(cd "$(dirname $0)/../../../../../"; pwd)
34+
CLASS_PATH=$MXNET_ROOT/scala-package/assembly/$platform-$hw_type/target/*:$MXNET_ROOT/scala-package/examples/target/*:$MXNET_ROOT/scala-package/examples/target/classes/lib/*:$MXNET_ROOT/scala-package/infer/target/*:$MXNET_ROOT/scala-package/examples/src/main/scala/org/apache/mxnetexamples/api/java/infer/imageclassifier/*
35+
36+
# model dir and prefix
37+
MODEL_DIR=$1
38+
# input image
39+
INPUT_IMG=$2
40+
# which input image dir
41+
INPUT_DIR=$3
42+
43+
java -Xmx8G -cp $CLASS_PATH \
44+
org.apache.mxnetexamples.infer.javapi.objectdetector.SSDClassifierExample \
45+
--model-path-prefix $MODEL_DIR \
46+
--input-image $INPUT_IMG \
47+
--input-dir $INPUT_DIR
Lines changed: 116 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,116 @@
1+
# Single Shot Multi Object Detection using Scala Inference API
2+
3+
In this example, you will learn how to use Scala Inference API to run Inference on pre-trained Single Shot Multi Object Detection (SSD) MXNet model.
4+
5+
The model is trained on the [Pascal VOC 2012 dataset](http://host.robots.ox.ac.uk/pascal/VOC/voc2012/index.html). The network is a SSD model built on Resnet50 as base network to extract image features. The model is trained to detect the following entities (classes): ['aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus', 'car', 'cat', 'chair', 'cow', 'diningtable', 'dog', 'horse', 'motorbike', 'person', 'pottedplant', 'sheep', 'sofa', 'train', 'tvmonitor']. For more details about the model, you can refer to the [MXNet SSD example](https://github.com/apache/incubator-mxnet/tree/master/example/ssd).
6+
7+
8+
## Contents
9+
10+
1. [Prerequisites](#prerequisites)
11+
2. [Download artifacts](#download-artifacts)
12+
3. [Setup datapath and parameters](#setup-datapath-and-parameters)
13+
4. [Run the image inference example](#run-the-image-inference-example)
14+
5. [Infer APIs](#infer-api-details)
15+
6. [Next steps](#next-steps)
16+
17+
18+
## Prerequisites
19+
20+
1. MXNet
21+
2. MXNet Scala Package
22+
3. [IntelliJ IDE (or alternative IDE) project setup](http://mxnet.incubator.apache.org/tutorials/scala/mxnet_scala_on_intellij.html) with the MXNet Scala Package
23+
4. wget
24+
25+
26+
## Setup Guide
27+
28+
### Download Artifacts
29+
#### Step 1
30+
You can download the files using the script `get_ssd_data.sh`. It will download and place the model files in a `model` folder and the test image files in a `image` folder in the current directory.
31+
From the `scala-package/examples/scripts/infer/imageclassifier/` folder run:
32+
33+
```bash
34+
./get_ssd_data.sh
35+
```
36+
37+
**Note**: You may need to run `chmod +x get_resnet_data.sh` before running this script.
38+
39+
Alternatively use the following links to download the Symbol and Params files via your browser:
40+
- [resnet50_ssd_model-symbol.json](https://s3.amazonaws.com/model-server/models/resnet50_ssd/resnet50_ssd_model-symbol.json)
41+
- [resnet50_ssd_model-0000.params](https://s3.amazonaws.com/model-server/models/resnet50_ssd/resnet50_ssd_model-0000.params)
42+
- [synset.txt](https://github.com/awslabs/mxnet-model-server/blob/master/examples/ssd/synset.txt)
43+
44+
In the pre-trained model, the `input_name` is `data` and shape is `(1, 3, 512, 512)`.
45+
This shape translates to: a batch of `1` image, the image has color and uses `3` channels (RGB), and the image has the dimensions of `512` pixels in height by `512` pixels in width.
46+
47+
`image/jpeg` is the expected input type, since this example's image pre-processor only supports the handling of binary JPEG images.
48+
49+
The output shape is `(1, 6132, 6)`. As with the input, the `1` is the number of images. `6132` is the number of prediction results, and `6` is for the size of each prediction. Each prediction contains the following components:
50+
- `Class`
51+
- `Accuracy`
52+
- `Xmin`
53+
- `Ymin`
54+
- `Xmax`
55+
- `Ymax`
56+
57+
58+
### Setup Datapath and Parameters
59+
#### Step 2
60+
The code `Line 31: val baseDir = System.getProperty("user.dir")` in the example will automatically searches the work directory you have defined. Please put the files in your [work directory](https://stackoverflow.com/questions/16239130/java-user-dir-property-what-exactly-does-it-mean). <!-- how do you define the work directory? -->
61+
62+
Alternatively, if you would like to use your own path, please change line 31 into your own path
63+
```scala
64+
val baseDir = <Your Own Path>
65+
```
66+
67+
The followings is the parameters defined for this example, you can find more information in the `class SSDClassifierExample`.
68+
69+
| Argument | Comments |
70+
| ----------------------------- | ---------------------------------------- |
71+
| `model-path-prefix` | Folder path with prefix to the model (including json, params, and any synset file). |
72+
| `input-image` | The image to run inference on. |
73+
| `input-dir` | The directory of images to run inference on. |
74+
75+
76+
## How to Run Inference
77+
After the previous steps, you should be able to run the code using the following script that will pass all of the required parameters to the Infer API.
78+
79+
From the `scala-package/examples/scripts/inferexample/objectdetector/` folder run:
80+
81+
```bash
82+
./run_ssd_example.sh ../models/resnet50_ssd/resnet50_ssd/resnet50_ssd_model ../images/dog.jpg ../images
83+
```
84+
85+
**Notes**:
86+
* These are relative paths to this script.
87+
* You may need to run `chmod +x run_ssd_example.sh` before running this script.
88+
89+
The example should give expected output as shown below:
90+
```
91+
Class: car
92+
Probabilties: 0.99847263
93+
(Coord:,312.21335,72.0291,456.01443,150.66176)
94+
Class: bicycle
95+
Probabilties: 0.90473825
96+
(Coord:,155.95807,149.96362,383.8369,418.94513)
97+
Class: dog
98+
Probabilties: 0.8226818
99+
(Coord:,83.82353,179.13998,206.63783,476.7875)
100+
```
101+
the outputs come from the the input image, with top3 predictions picked.
102+
103+
104+
## Infer API Details
105+
This example uses ObjectDetector class provided by MXNet's scala package Infer APIs. It provides methods to load the images, create NDArray out of Java BufferedImage and run prediction using Classifier and Predictor APIs.
106+
107+
108+
## References
109+
This documentation used the model and inference setup guide from the [MXNet Model Server SSD example](https://github.com/awslabs/mxnet-model-server/blob/master/examples/ssd/README.md).
110+
111+
112+
## Next Steps
113+
114+
Check out the following related tutorials and examples for the Infer API:
115+
116+
* [Image Classification with the MXNet Scala Infer API](../imageclassifier/README.md)
Lines changed: 199 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,199 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.mxnetexamples.infer.javapi.objectdetector;
19+
20+
import org.apache.mxnet.infer.javaapi.ObjectDetectorOutput;
21+
import org.kohsuke.args4j.CmdLineParser;
22+
import org.kohsuke.args4j.Option;
23+
24+
import org.slf4j.Logger;
25+
import org.slf4j.LoggerFactory;
26+
27+
import org.apache.mxnet.javaapi.*;
28+
import org.apache.mxnet.infer.javaapi.ObjectDetector;
29+
30+
// scalastyle:off
31+
import java.awt.image.BufferedImage;
32+
// scalastyle:on
33+
34+
import java.util.ArrayList;
35+
import java.util.Arrays;
36+
import java.util.List;
37+
38+
import java.io.File;
39+
40+
public class SSDClassifierExample {
41+
@Option(name = "--model-path-prefix", usage = "input model directory and prefix of the model")
42+
private String modelPathPrefix = "/model/ssd_resnet50_512";
43+
@Option(name = "--input-image", usage = "the input image")
44+
private String inputImagePath = "/images/dog.jpg";
45+
@Option(name = "--input-dir", usage = "the input batch of images directory")
46+
private String inputImageDir = "/images/";
47+
48+
final static Logger logger = LoggerFactory.getLogger(SSDClassifierExample.class);
49+
50+
static List<List<ObjectDetectorOutput>>
51+
runObjectDetectionSingle(String modelPathPrefix, String inputImagePath, List<Context> context) {
52+
Shape inputShape = new Shape(new int[] {1, 3, 512, 512});
53+
List<DataDesc> inputDescriptors = new ArrayList<DataDesc>();
54+
inputDescriptors.add(new DataDesc("data", inputShape, DType.Float32(), "NCHW"));
55+
BufferedImage img = ObjectDetector.loadImageFromFile(inputImagePath);
56+
ObjectDetector objDet = new ObjectDetector(modelPathPrefix, inputDescriptors, context, 0);
57+
return objDet.imageObjectDetect(img, 3);
58+
}
59+
60+
static List<List<List<ObjectDetectorOutput>>>
61+
runObjectDetectionBatch(String modelPathPrefix, String inputImageDir, List<Context> context) {
62+
Shape inputShape = new Shape(new int[]{1, 3, 512, 512});
63+
List<DataDesc> inputDescriptors = new ArrayList<DataDesc>();
64+
inputDescriptors.add(new DataDesc("data", inputShape, DType.Float32(), "NCHW"));
65+
ObjectDetector objDet = new ObjectDetector(modelPathPrefix, inputDescriptors, context, 0);
66+
67+
// Loading batch of images from the directory path
68+
List<List<String>> batchFiles = generateBatches(inputImageDir, 20);
69+
List<List<List<ObjectDetectorOutput>>> outputList
70+
= new ArrayList<List<List<ObjectDetectorOutput>>>();
71+
72+
for (List<String> batchFile : batchFiles) {
73+
List<BufferedImage> imgList = ObjectDetector.loadInputBatch(batchFile);
74+
// Running inference on batch of images loaded in previous step
75+
List<List<ObjectDetectorOutput>> tmp
76+
= objDet.imageBatchObjectDetect(imgList, 5);
77+
outputList.add(tmp);
78+
}
79+
return outputList;
80+
}
81+
82+
static List<List<String>> generateBatches(String inputImageDirPath, int batchSize) {
83+
File dir = new File(inputImageDirPath);
84+
85+
List<List<String>> output = new ArrayList<List<String>>();
86+
List<String> batch = new ArrayList<String>();
87+
for (File imgFile : dir.listFiles()) {
88+
batch.add(imgFile.getPath());
89+
if (batch.size() == batchSize) {
90+
output.add(batch);
91+
batch = new ArrayList<String>();
92+
}
93+
}
94+
if (batch.size() > 0) {
95+
output.add(batch);
96+
}
97+
return output;
98+
}
99+
100+
public static void main(String[] args) {
101+
SSDClassifierExample inst = new SSDClassifierExample();
102+
CmdLineParser parser = new CmdLineParser(inst);
103+
try {
104+
parser.parseArgument(args);
105+
} catch (Exception e) {
106+
logger.error(e.getMessage(), e);
107+
parser.printUsage(System.err);
108+
System.exit(1);
109+
}
110+
111+
String mdprefixDir = inst.modelPathPrefix;
112+
String imgPath = inst.inputImagePath;
113+
String imgDir = inst.inputImageDir;
114+
115+
if (!checkExist(Arrays.asList(mdprefixDir + "-symbol.json", imgDir, imgPath))) {
116+
logger.error("Model or input image path does not exist");
117+
System.exit(1);
118+
}
119+
120+
List<Context> context = new ArrayList<Context>();
121+
if (System.getenv().containsKey("SCALA_TEST_ON_GPU") &&
122+
Integer.valueOf(System.getenv("SCALA_TEST_ON_GPU")) == 1) {
123+
context.add(Context.gpu());
124+
} else {
125+
context.add(Context.cpu());
126+
}
127+
128+
try {
129+
Shape inputShape = new Shape(new int[] {1, 3, 512, 512});
130+
Shape outputShape = new Shape(new int[] {1, 6132, 6});
131+
132+
133+
int width = inputShape.get(2);
134+
int height = inputShape.get(3);
135+
String outputStr = "\n";
136+
137+
List<List<ObjectDetectorOutput>> output
138+
= runObjectDetectionSingle(mdprefixDir, imgPath, context);
139+
140+
for (List<ObjectDetectorOutput> ele : output) {
141+
for (ObjectDetectorOutput i : ele) {
142+
outputStr += "Class: " + i.getClassName() + "\n";
143+
outputStr += "Probabilties: " + i.getProbability() + "\n";
144+
145+
List<Float> coord = Arrays.asList(i.getXMin() * width,
146+
i.getXMax() * height, i.getYMin() * width, i.getYMax() * height);
147+
StringBuilder sb = new StringBuilder();
148+
for (float c: coord) {
149+
sb.append(", ").append(c);
150+
}
151+
outputStr += "Coord:" + sb.substring(2)+ "\n";
152+
}
153+
}
154+
logger.info(outputStr);
155+
156+
List<List<List<ObjectDetectorOutput>>> outputList =
157+
runObjectDetectionBatch(mdprefixDir, imgDir, context);
158+
159+
outputStr = "\n";
160+
int index = 0;
161+
for (List<List<ObjectDetectorOutput>> i: outputList) {
162+
for (List<ObjectDetectorOutput> j : i) {
163+
outputStr += "*** Image " + (index + 1) + "***" + "\n";
164+
for (ObjectDetectorOutput k : j) {
165+
outputStr += "Class: " + k.getClassName() + "\n";
166+
outputStr += "Probabilties: " + k.getProbability() + "\n";
167+
List<Float> coord = Arrays.asList(k.getXMin() * width,
168+
k.getXMax() * height, k.getYMin() * width, k.getYMax() * height);
169+
170+
StringBuilder sb = new StringBuilder();
171+
for (float c : coord) {
172+
sb.append(", ").append(c);
173+
}
174+
outputStr += "Coord:" + sb.substring(2) + "\n";
175+
}
176+
index++;
177+
}
178+
}
179+
logger.info(outputStr);
180+
181+
} catch (Exception e) {
182+
logger.error(e.getMessage(), e);
183+
parser.printUsage(System.err);
184+
System.exit(1);
185+
}
186+
System.exit(0);
187+
}
188+
189+
static Boolean checkExist(List<String> arr) {
190+
Boolean exist = true;
191+
for (String item : arr) {
192+
exist = new File(item).exists() && exist;
193+
if (!exist) {
194+
logger.error("Cannot find: " + item);
195+
}
196+
}
197+
return exist;
198+
}
199+
}

0 commit comments

Comments
 (0)