Skip to content

Commit a6ded8c

Browse files
committed
[pytorch] Add BigGAN demo
1 parent f145614 commit a6ded8c

File tree

8 files changed

+1394
-2
lines changed

8 files changed

+1394
-2
lines changed

api/src/main/java/ai/djl/ndarray/BaseNDManager.java

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
import ai.djl.ndarray.types.DataType;
1717
import ai.djl.ndarray.types.Shape;
1818
import ai.djl.util.PairList;
19+
import ai.djl.util.RandomUtils;
1920
import java.nio.Buffer;
2021
import java.nio.file.Path;
2122
import java.util.UUID;
@@ -156,7 +157,19 @@ public NDArray randomNormal(float loc, float scale, Shape shape, DataType dataTy
156157
/** {@inheritDoc} */
157158
@Override
158159
public NDArray truncatedNormal(float loc, float scale, Shape shape, DataType dataType) {
159-
throw new UnsupportedOperationException("Not supported!");
160+
int sampleSize = (int) shape.size();
161+
double[] dist = new double[sampleSize];
162+
163+
for (int i = 0; i < sampleSize; i++) {
164+
double sample = RandomUtils.nextGaussian();
165+
while (sample < -2 || sample > 2) {
166+
sample = RandomUtils.nextGaussian();
167+
}
168+
169+
dist[i] = sample;
170+
}
171+
172+
return create(dist).addi(loc).muli(scale).reshape(shape).toType(dataType, false);
160173
}
161174

162175
/** {@inheritDoc} */

examples/build.gradle

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ dependencies {
4444
}
4545

4646
application {
47-
mainClassName = System.getProperty("main", "ai.djl.examples.inference.ObjectDetection")
47+
mainClassName = System.getProperty("main", "ai.djl.examples.inference.biggan.Generator")
4848
}
4949

5050
run {
Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,98 @@
1+
/*
2+
* Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
5+
* with the License. A copy of the License is located at
6+
*
7+
* http://aws.amazon.com/apache2.0/
8+
*
9+
* or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
10+
* OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
11+
* and limitations under the License.
12+
*/
13+
package ai.djl.examples.inference.biggan;
14+
15+
import java.io.IOException;
16+
import java.nio.file.Files;
17+
import java.nio.file.Paths;
18+
import java.util.ArrayList;
19+
import java.util.List;
20+
import java.util.Map;
21+
import java.util.concurrent.ConcurrentHashMap;
22+
import org.slf4j.Logger;
23+
import org.slf4j.LoggerFactory;
24+
25+
public final class BigGANCategory {
26+
private static final Logger logger = LoggerFactory.getLogger(BigGANCategory.class);
27+
28+
public static final int NUMBER_OF_CATEGORIES = 1000;
29+
private static final Map<String, BigGANCategory> CATEGORIES_BY_NAME =
30+
new ConcurrentHashMap<>(NUMBER_OF_CATEGORIES);
31+
private static String[] categoriesById;
32+
33+
private int id;
34+
private String[] names;
35+
36+
static {
37+
try {
38+
parseCategories();
39+
} catch (IOException e) {
40+
logger.error("Error parsing the ImageNet categories: {}", e);
41+
}
42+
createCategoriesByName();
43+
}
44+
45+
private BigGANCategory(int id, String[] names) {
46+
this.id = id;
47+
this.names = names;
48+
}
49+
50+
public int getId() {
51+
return id;
52+
}
53+
54+
public String[] getNames() {
55+
return names.clone();
56+
}
57+
58+
public static BigGANCategory id(int id) {
59+
String names = categoriesById[id];
60+
int index = names.indexOf(',');
61+
if (index < 0) {
62+
return of(names);
63+
} else {
64+
return of(names.substring(0, index));
65+
}
66+
}
67+
68+
public static BigGANCategory of(String name) {
69+
if (!CATEGORIES_BY_NAME.containsKey(name)) {
70+
throw new IllegalArgumentException(name + " is not a valid category.");
71+
}
72+
return CATEGORIES_BY_NAME.get(name);
73+
}
74+
75+
private static void createCategoriesByName() {
76+
for (int i = 0; i < NUMBER_OF_CATEGORIES; i++) {
77+
String[] categoryNames = categoriesById[i].split(", ");
78+
BigGANCategory category = new BigGANCategory(i, categoryNames);
79+
80+
for (String name : categoryNames) {
81+
CATEGORIES_BY_NAME.put(name, category);
82+
}
83+
}
84+
}
85+
86+
private static void parseCategories() throws IOException {
87+
String filePath = "src/main/resources/categories.txt";
88+
89+
List<String> fileLines = Files.readAllLines(Paths.get(filePath));
90+
List<String> categories = new ArrayList<>(NUMBER_OF_CATEGORIES);
91+
for (String line : fileLines) {
92+
int nameIndex = line.indexOf(':') + 2;
93+
categories.add(line.substring(nameIndex));
94+
}
95+
96+
categoriesById = categories.toArray(new String[] {});
97+
}
98+
}
Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
/*
2+
* Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
5+
* with the License. A copy of the License is located at
6+
*
7+
* http://aws.amazon.com/apache2.0/
8+
*
9+
* or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
10+
* OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
11+
* and limitations under the License.
12+
*/
13+
package ai.djl.examples.inference.biggan;
14+
15+
public final class BigGANInput {
16+
private int sampleSize;
17+
private float truncation;
18+
private BigGANCategory category;
19+
20+
public BigGANInput(int sampleSize, float truncation, BigGANCategory category) {
21+
this.sampleSize = sampleSize;
22+
this.truncation = truncation;
23+
this.category = category;
24+
}
25+
26+
BigGANInput(Builder builder) {
27+
this.sampleSize = builder.sampleSize;
28+
this.truncation = builder.truncation;
29+
this.category = builder.category;
30+
}
31+
32+
public int getSampleSize() {
33+
return sampleSize;
34+
}
35+
36+
public float getTruncation() {
37+
return truncation;
38+
}
39+
40+
public BigGANCategory getCategory() {
41+
return category;
42+
}
43+
44+
public static Builder builder() {
45+
return new Builder();
46+
}
47+
48+
public static final class Builder {
49+
private int sampleSize = 5;
50+
private float truncation = 0.5f;
51+
private BigGANCategory category;
52+
53+
Builder() {
54+
category = BigGANCategory.of("Egyptian cat");
55+
}
56+
57+
public Builder optSampleSize(int sampleSize) {
58+
this.sampleSize = sampleSize;
59+
return this;
60+
}
61+
62+
public Builder optTruncation(float truncation) {
63+
this.truncation = truncation;
64+
return this;
65+
}
66+
67+
public Builder setCategory(BigGANCategory category) {
68+
this.category = category;
69+
return this;
70+
}
71+
72+
public BigGANInput build() {
73+
return new BigGANInput(this);
74+
}
75+
}
76+
}
Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
/*
2+
* Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
5+
* with the License. A copy of the License is located at
6+
*
7+
* http://aws.amazon.com/apache2.0/
8+
*
9+
* or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
10+
* OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
11+
* and limitations under the License.
12+
*/
13+
package ai.djl.examples.inference.biggan;
14+
15+
import ai.djl.engine.Engine;
16+
import ai.djl.modality.cv.Image;
17+
import ai.djl.modality.cv.ImageFactory;
18+
import ai.djl.ndarray.NDArray;
19+
import ai.djl.ndarray.NDList;
20+
import ai.djl.ndarray.NDManager;
21+
import ai.djl.ndarray.types.DataType;
22+
import ai.djl.ndarray.types.Shape;
23+
import ai.djl.translate.Batchifier;
24+
import ai.djl.translate.Translator;
25+
import ai.djl.translate.TranslatorContext;
26+
import org.slf4j.Logger;
27+
import org.slf4j.LoggerFactory;
28+
29+
final class BigGANTranslator implements Translator<BigGANInput, Image[]> {
30+
private static final Logger logger = LoggerFactory.getLogger(BigGANTranslator.class);
31+
private static final int SEED_COLUMN_SIZE = 128;
32+
33+
@Override
34+
public Image[] processOutput(TranslatorContext ctx, NDList list) throws Exception {
35+
logOutputList(list);
36+
37+
NDArray output = list.get(0).addi(1).muli(128).clip(0, 255).toType(DataType.UINT8, false);
38+
39+
int sampleSize = (int) output.getShape().get(0);
40+
Image[] images = new Image[sampleSize];
41+
42+
for (int i = 0; i < sampleSize; i++) {
43+
images[i] = ImageFactory.getInstance().fromNDArray(output.get(i));
44+
}
45+
46+
return images;
47+
}
48+
49+
private void logOutputList(NDList list) {
50+
logger.info("");
51+
logger.info("MY OUTPUT:");
52+
list.forEach(array -> logger.info(" out: {}", array.getShape()));
53+
}
54+
55+
@Override
56+
public NDList processInput(TranslatorContext ctx, BigGANInput input) throws Exception {
57+
Engine.getInstance().setRandomSeed(0);
58+
NDManager manager = ctx.getNDManager();
59+
60+
NDArray categoryArray = createCategoryArray(manager, input);
61+
NDArray seed =
62+
manager.truncatedNormal(new Shape(input.getSampleSize(), SEED_COLUMN_SIZE))
63+
.muli(input.getTruncation());
64+
NDArray truncation = manager.create(input.getTruncation());
65+
66+
logInputArrays(categoryArray, seed, truncation);
67+
return new NDList(seed, categoryArray, truncation);
68+
}
69+
70+
private NDArray createCategoryArray(NDManager manager, BigGANInput input) {
71+
int categoryId = input.getCategory().getId();
72+
int sampleSize = input.getSampleSize();
73+
74+
int[] indices = new int[sampleSize];
75+
for (int i = 0; i < sampleSize; i++) {
76+
indices[i] = categoryId;
77+
}
78+
return manager.create(indices).oneHot(BigGANCategory.NUMBER_OF_CATEGORIES);
79+
}
80+
81+
private void logInputArrays(NDArray categoryArray, NDArray seed, NDArray truncation) {
82+
logger.info("");
83+
logger.info("MY INPUTS: ");
84+
logger.info(" y: {}", categoryArray.getShape());
85+
logger.info(" z: {}", seed.get(":, :10"));
86+
logger.info(" truncation: {}", truncation.getShape());
87+
}
88+
89+
@Override
90+
public Batchifier getBatchifier() {
91+
return null;
92+
}
93+
}

0 commit comments

Comments
 (0)