Skip to content

Commit 5d8f8ce

Browse files
committed
Refactor BlockFactory interface
Change-Id: I7e9aa60f541c00852c548332338ee0fc914ee92f
1 parent 5f85fcc commit 5d8f8ce

File tree

19 files changed

+214
-291
lines changed

19 files changed

+214
-291
lines changed

api/src/main/java/ai/djl/BaseModel.java

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -19,13 +19,11 @@
1919
import ai.djl.ndarray.types.DataType;
2020
import ai.djl.ndarray.types.Shape;
2121
import ai.djl.nn.Block;
22-
import ai.djl.nn.BlockFactory;
2322
import ai.djl.nn.SymbolBlock;
2423
import ai.djl.training.ParameterStore;
2524
import ai.djl.training.Trainer;
2625
import ai.djl.training.TrainingConfig;
2726
import ai.djl.translate.Translator;
28-
import ai.djl.util.ClassLoaderUtils;
2927
import ai.djl.util.Pair;
3028
import ai.djl.util.PairList;
3129
import ai.djl.util.Utils;
@@ -217,14 +215,6 @@ protected void setModelDir(Path modelDir) {
217215
this.modelDir = modelDir.toAbsolutePath();
218216
}
219217

220-
protected Block loadFromBlockFactory() {
221-
BlockFactory factory = ClassLoaderUtils.findImplementation(modelDir, null);
222-
if (factory == null) {
223-
return null;
224-
}
225-
return factory.newBlock(manager);
226-
}
227-
228218
/** {@inheritDoc} */
229219
@Override
230220
public void save(Path modelPath, String newModelName) throws IOException {

api/src/main/java/ai/djl/nn/BlockFactory.java

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,12 @@
1212
*/
1313
package ai.djl.nn;
1414

15-
import ai.djl.ndarray.NDManager;
15+
import ai.djl.Model;
1616
import ai.djl.repository.zoo.ModelZoo;
17+
import java.io.IOException;
1718
import java.io.Serializable;
19+
import java.nio.file.Path;
20+
import java.util.Map;
1821

1922
/**
2023
* Block factory is a component to make standard for block creating and saving procedure. Block
@@ -27,8 +30,11 @@ public interface BlockFactory extends Serializable {
2730
/**
2831
* Constructs the uninitialized block.
2932
*
30-
* @param manager the manager to assign to block
33+
* @param model the model of the block
34+
* @param modelPath the directory of the model location
35+
* @param arguments the block creation arguments
3136
* @return the uninitialized block
37+
* @throws IOException if IO operation fails during creating block
3238
*/
33-
Block newBlock(NDManager manager);
39+
Block newBlock(Model model, Path modelPath, Map<String, ?> arguments) throws IOException;
3440
}

api/src/main/java/ai/djl/repository/zoo/BaseModelLoader.java

Lines changed: 23 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,8 @@
2323
import ai.djl.modality.cv.Image;
2424
import ai.djl.modality.cv.translator.ImageClassificationTranslatorFactory;
2525
import ai.djl.ndarray.NDList;
26+
import ai.djl.nn.Block;
27+
import ai.djl.nn.BlockFactory;
2628
import ai.djl.repository.Artifact;
2729
import ai.djl.repository.MRL;
2830
import ai.djl.repository.Repository;
@@ -32,6 +34,7 @@
3234
import ai.djl.translate.TranslateException;
3335
import ai.djl.translate.Translator;
3436
import ai.djl.translate.TranslatorFactory;
37+
import ai.djl.util.ClassLoaderUtils;
3538
import ai.djl.util.Pair;
3639
import ai.djl.util.Progress;
3740
import java.io.IOException;
@@ -155,10 +158,14 @@ public <I, O> ZooModel<I, O> loadModel(Criteria<I, O> criteria)
155158
modelName = artifact.getName();
156159
}
157160

158-
Model model = createModel(modelName, criteria.getDevice(), artifact, arguments, engine);
159-
if (criteria.getBlock() != null) {
160-
model.setBlock(criteria.getBlock());
161-
}
161+
Model model =
162+
createModel(
163+
modelPath,
164+
modelName,
165+
criteria.getDevice(),
166+
criteria.getBlock(),
167+
arguments,
168+
engine);
162169
model.load(modelPath, null, options);
163170
Translator<I, O> translator = factory.newInstance(model, arguments);
164171
return new ZooModel<>(model, translator);
@@ -182,13 +189,23 @@ public List<Artifact> listModels() throws IOException {
182189
}
183190

184191
protected Model createModel(
192+
Path modelPath,
185193
String name,
186194
Device device,
187-
Artifact artifact,
195+
Block block,
188196
Map<String, Object> arguments,
189197
String engine)
190198
throws IOException {
191-
return Model.newInstance(name, device, engine);
199+
Model model = Model.newInstance(name, device, engine);
200+
if (block == null) {
201+
String className = (String) arguments.get("blockFactory");
202+
BlockFactory factory = ClassLoaderUtils.findImplementation(modelPath, className);
203+
if (factory != null) {
204+
block = factory.newBlock(model, modelPath, arguments);
205+
}
206+
}
207+
model.setBlock(block);
208+
return model;
192209
}
193210

194211
/** {@inheritDoc} */

extensions/fasttext/src/main/java/ai/djl/fasttext/zoo/nlp/textclassification/TextClassificationModelLoader.java

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
import ai.djl.fasttext.FtModel;
2020
import ai.djl.fasttext.zoo.FtModelZoo;
2121
import ai.djl.modality.Classifications;
22-
import ai.djl.repository.Artifact;
22+
import ai.djl.nn.Block;
2323
import ai.djl.repository.MRL;
2424
import ai.djl.repository.Repository;
2525
import ai.djl.repository.zoo.BaseModelLoader;
@@ -30,6 +30,7 @@
3030
import ai.djl.translate.TranslatorFactory;
3131
import ai.djl.util.Pair;
3232
import java.io.IOException;
33+
import java.nio.file.Path;
3334
import java.util.Map;
3435

3536
/** Model loader for fastText cooking stackexchange models. */
@@ -68,9 +69,10 @@ public ZooModel<String, Classifications> loadModel()
6869
/** {@inheritDoc} */
6970
@Override
7071
protected Model createModel(
72+
Path modelPath,
7173
String name,
7274
Device device,
73-
Artifact artifact,
75+
Block block,
7476
Map<String, Object> arguments,
7577
String engine) {
7678
return new FtModel(name);

integration/src/main/java/ai/djl/integration/tests/nn/BlockFactoryTest.java

Lines changed: 5 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -30,59 +30,21 @@
3030
import ai.djl.repository.zoo.Criteria;
3131
import ai.djl.repository.zoo.ModelNotFoundException;
3232
import ai.djl.repository.zoo.ZooModel;
33-
import ai.djl.testing.Assertions;
3433
import ai.djl.training.ParameterStore;
3534
import ai.djl.training.util.ProgressBar;
36-
import ai.djl.translate.NoopTranslator;
3735
import ai.djl.translate.TranslateException;
3836
import ai.djl.util.Utils;
3937
import ai.djl.util.ZipUtils;
40-
import java.io.ByteArrayInputStream;
41-
import java.io.ByteArrayOutputStream;
42-
import java.io.DataInputStream;
43-
import java.io.DataOutputStream;
4438
import java.io.IOException;
4539
import java.nio.file.Files;
4640
import java.nio.file.Path;
4741
import java.nio.file.Paths;
42+
import java.util.Map;
4843
import org.testng.Assert;
4944
import org.testng.annotations.Test;
5045

5146
public class BlockFactoryTest {
5247

53-
@Test
54-
public void testBlockLoadingSaving()
55-
throws IOException, ModelNotFoundException, MalformedModelException,
56-
TranslateException {
57-
TestBlockFactory factory = new TestBlockFactory();
58-
Model model = factory.getRemoveLastBlockModel();
59-
try (NDManager manager = NDManager.newBaseManager()) {
60-
Block block = model.getBlock();
61-
block.forward(
62-
new ParameterStore(manager, true),
63-
new NDList(manager.ones(new Shape(1, 3, 32, 32))),
64-
true);
65-
ByteArrayOutputStream os = new ByteArrayOutputStream();
66-
block.saveParameters(new DataOutputStream(os));
67-
ByteArrayInputStream bis = new ByteArrayInputStream(os.toByteArray());
68-
Block newBlock = factory.newBlock(manager);
69-
newBlock.loadParameters(manager, new DataInputStream(bis));
70-
try (Model test = Model.newInstance("test")) {
71-
test.setBlock(newBlock);
72-
try (Predictor<NDList, NDList> predOrigin =
73-
model.newPredictor(new NoopTranslator());
74-
Predictor<NDList, NDList> predDest =
75-
test.newPredictor(new NoopTranslator())) {
76-
NDList input = new NDList(manager.ones(new Shape(1, 3, 32, 32)));
77-
NDList originOut = predOrigin.predict(input);
78-
NDList destOut = predDest.predict(input);
79-
Assertions.assertAlmostEquals(originOut, destOut);
80-
}
81-
}
82-
}
83-
model.close();
84-
}
85-
8648
@Test
8749
public void testBlockFactoryLoadingFromZip()
8850
throws MalformedModelException, ModelNotFoundException, IOException,
@@ -97,9 +59,9 @@ public void testBlockFactoryLoadingFromZip()
9759
.optModelPath(zipPath)
9860
.optModelName("exported")
9961
.build();
100-
try (NDManager manager = NDManager.newBaseManager();
101-
ZooModel<NDList, NDList> model = criteria.loadModel();
62+
try (ZooModel<NDList, NDList> model = criteria.loadModel();
10263
Predictor<NDList, NDList> pred = model.newPredictor()) {
64+
NDManager manager = model.getNDManager();
10365
NDList destOut = pred.predict(new NDList(manager.ones(new Shape(1, 3, 32, 32))));
10466
Assert.assertEquals(destOut.singletonOrThrow().getShape(), new Shape(1, 10));
10567
}
@@ -136,9 +98,9 @@ public static class TestBlockFactory implements BlockFactory {
13698
private static final long serialVersionUID = 1234567L;
13799

138100
@Override
139-
public Block newBlock(NDManager manager) {
101+
public Block newBlock(Model model, Path modelPath, Map<String, ?> arguments) {
140102
SequentialBlock newBlock = new SequentialBlock();
141-
newBlock.add(SymbolBlock.newInstance(manager));
103+
newBlock.add(SymbolBlock.newInstance(model.getNDManager()));
142104
newBlock.add(Linear.builder().setUnits(10).build());
143105
return newBlock;
144106
}

model-zoo/src/main/java/ai/djl/basicmodelzoo/BasicModelZoo.java

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,10 @@
1212
*/
1313
package ai.djl.basicmodelzoo;
1414

15-
import ai.djl.basicmodelzoo.cv.classification.MlpModelLoader;
16-
import ai.djl.basicmodelzoo.cv.classification.ResNetModelLoader;
17-
import ai.djl.basicmodelzoo.cv.object_detection.ssd.SsdModelLoader;
15+
import ai.djl.modality.cv.zoo.ImageClassificationModelLoader;
16+
import ai.djl.modality.cv.zoo.ObjectDetectionModelLoader;
1817
import ai.djl.repository.Repository;
18+
import ai.djl.repository.zoo.ModelLoader;
1919
import ai.djl.repository.zoo.ModelZoo;
2020
import java.util.HashSet;
2121
import java.util.Set;
@@ -25,11 +25,15 @@ public class BasicModelZoo implements ModelZoo {
2525

2626
private static final String REPO_URL = "https://mlrepo.djl.ai/";
2727
private static final Repository REPOSITORY = Repository.newInstance("zoo", REPO_URL);
28+
private static final ModelZoo ZOO = new BasicModelZoo();
2829
public static final String GROUP_ID = "ai.djl.zoo";
2930

30-
public static final ResNetModelLoader RESNET = new ResNetModelLoader(REPOSITORY);
31-
public static final MlpModelLoader MLP = new MlpModelLoader(REPOSITORY);
32-
public static final SsdModelLoader SSD = new SsdModelLoader(REPOSITORY);
31+
public static final ModelLoader RESNET =
32+
new ImageClassificationModelLoader(REPOSITORY, GROUP_ID, "resnet", "0.0.2", ZOO);
33+
public static final ModelLoader MLP =
34+
new ImageClassificationModelLoader(REPOSITORY, GROUP_ID, "mlp", "0.0.3", ZOO);
35+
public static final ModelLoader SSD =
36+
new ObjectDetectionModelLoader(REPOSITORY, GROUP_ID, "ssd", "0.0.2", ZOO);
3337

3438
/** {@inheritDoc} */
3539
@Override
Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
/*
2+
* Copyright 2021 Amazon.com, Inc. or its affiliates. All Rights Reserved.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
5+
* with the License. A copy of the License is located at
6+
*
7+
* http://aws.amazon.com/apache2.0/
8+
*
9+
* or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
10+
* OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
11+
* and limitations under the License.
12+
*/
13+
package ai.djl.basicmodelzoo.basic;
14+
15+
import ai.djl.Model;
16+
import ai.djl.nn.Block;
17+
import ai.djl.nn.BlockFactory;
18+
import java.nio.file.Path;
19+
import java.util.List;
20+
import java.util.Map;
21+
22+
/** A {@link BlockFactory} class that creates MLP block. */
23+
public class MlpBlockFactory implements BlockFactory {
24+
25+
private static final long serialVersionUID = 1L;
26+
27+
/** {@inheritDoc} */
28+
@Override
29+
@SuppressWarnings("unchecked")
30+
public Block newBlock(Model model, Path modelPath, Map<String, ?> arguments) {
31+
Double width = (Double) arguments.get("width");
32+
if (width == null) {
33+
width = 28d;
34+
}
35+
Double height = (Double) arguments.get("height");
36+
if (height == null) {
37+
height = 28d;
38+
}
39+
int input = width.intValue() * height.intValue();
40+
int output = ((Double) arguments.get("output")).intValue();
41+
int[] hidden =
42+
((List<Double>) arguments.get("hidden"))
43+
.stream()
44+
.mapToInt(Double::intValue)
45+
.toArray();
46+
47+
return new Mlp(input, output, hidden);
48+
}
49+
}

model-zoo/src/main/java/ai/djl/basicmodelzoo/cv/classification/MlpModelLoader.java

Lines changed: 0 additions & 64 deletions
This file was deleted.

0 commit comments

Comments
 (0)