Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 0 additions & 10 deletions api/src/main/java/ai/djl/BaseModel.java
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,11 @@
import ai.djl.ndarray.types.DataType;
import ai.djl.ndarray.types.Shape;
import ai.djl.nn.Block;
import ai.djl.nn.BlockFactory;
import ai.djl.nn.SymbolBlock;
import ai.djl.training.ParameterStore;
import ai.djl.training.Trainer;
import ai.djl.training.TrainingConfig;
import ai.djl.translate.Translator;
import ai.djl.util.ClassLoaderUtils;
import ai.djl.util.Pair;
import ai.djl.util.PairList;
import ai.djl.util.Utils;
Expand Down Expand Up @@ -217,14 +215,6 @@ protected void setModelDir(Path modelDir) {
this.modelDir = modelDir.toAbsolutePath();
}

protected Block loadFromBlockFactory() {
BlockFactory factory = ClassLoaderUtils.findImplementation(modelDir, null);
if (factory == null) {
return null;
}
return factory.newBlock(manager);
}

/** {@inheritDoc} */
@Override
public void save(Path modelPath, String newModelName) throws IOException {
Expand Down
12 changes: 9 additions & 3 deletions api/src/main/java/ai/djl/nn/BlockFactory.java
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,12 @@
*/
package ai.djl.nn;

import ai.djl.ndarray.NDManager;
import ai.djl.Model;
import ai.djl.repository.zoo.ModelZoo;
import java.io.IOException;
import java.io.Serializable;
import java.nio.file.Path;
import java.util.Map;

/**
* Block factory is a component to make standard for block creating and saving procedure. Block
Expand All @@ -27,8 +30,11 @@ public interface BlockFactory extends Serializable {
/**
* Constructs the uninitialized block.
*
* @param manager the manager to assign to block
* @param model the model of the block
* @param modelPath the directory of the model location
* @param arguments the block creation arguments
* @return the uninitialized block
* @throws IOException if IO operation fails during creating block
*/
Block newBlock(NDManager manager);
Block newBlock(Model model, Path modelPath, Map<String, ?> arguments) throws IOException;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this right? I thought the plan was to use BlockFactory to deprecate the Map<String, ?> arguments and replace it with serializing the BlockFactory?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The purpose of BlockFactory is to create the Block then you can load the parameters into the Block.
The problem is the BlockFactory may need extra information:

  1. Some files in the model directory
  2. some model specific arguments. If we hard-code those arguments, which will make the BlockFactory implementation not re-useable.

}
31 changes: 25 additions & 6 deletions api/src/main/java/ai/djl/repository/zoo/BaseModelLoader.java
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@
import ai.djl.modality.cv.Image;
import ai.djl.modality.cv.translator.ImageClassificationTranslatorFactory;
import ai.djl.ndarray.NDList;
import ai.djl.nn.Block;
import ai.djl.nn.BlockFactory;
import ai.djl.repository.Artifact;
import ai.djl.repository.MRL;
import ai.djl.repository.Repository;
Expand All @@ -32,6 +34,7 @@
import ai.djl.translate.TranslateException;
import ai.djl.translate.Translator;
import ai.djl.translate.TranslatorFactory;
import ai.djl.util.ClassLoaderUtils;
import ai.djl.util.Pair;
import ai.djl.util.Progress;
import java.io.IOException;
Expand Down Expand Up @@ -155,10 +158,14 @@ public <I, O> ZooModel<I, O> loadModel(Criteria<I, O> criteria)
modelName = artifact.getName();
}

Model model = createModel(modelName, criteria.getDevice(), artifact, arguments, engine);
if (criteria.getBlock() != null) {
model.setBlock(criteria.getBlock());
}
Model model =
createModel(
modelPath,
modelName,
criteria.getDevice(),
criteria.getBlock(),
arguments,
engine);
model.load(modelPath, null, options);
Translator<I, O> translator = factory.newInstance(model, arguments);
return new ZooModel<>(model, translator);
Expand All @@ -182,13 +189,25 @@ public List<Artifact> listModels() throws IOException {
}

protected Model createModel(
Path modelPath,
String name,
Device device,
Artifact artifact,
Block block,
Map<String, Object> arguments,
String engine)
throws IOException {
return Model.newInstance(name, device, engine);
Model model = Model.newInstance(name, device, engine);
if (block == null) {
String className = (String) arguments.get("blockFactory");
BlockFactory factory = ClassLoaderUtils.findImplementation(modelPath, className);
if (factory != null) {
block = factory.newBlock(model, modelPath, arguments);
}
}
if (block != null) {
model.setBlock(block);
}
return model;
}

/** {@inheritDoc} */
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
import ai.djl.fasttext.FtModel;
import ai.djl.fasttext.zoo.FtModelZoo;
import ai.djl.modality.Classifications;
import ai.djl.repository.Artifact;
import ai.djl.nn.Block;
import ai.djl.repository.MRL;
import ai.djl.repository.Repository;
import ai.djl.repository.zoo.BaseModelLoader;
Expand All @@ -30,6 +30,7 @@
import ai.djl.translate.TranslatorFactory;
import ai.djl.util.Pair;
import java.io.IOException;
import java.nio.file.Path;
import java.util.Map;

/** Model loader for fastText cooking stackexchange models. */
Expand Down Expand Up @@ -68,9 +69,10 @@ public ZooModel<String, Classifications> loadModel()
/** {@inheritDoc} */
@Override
protected Model createModel(
Path modelPath,
String name,
Device device,
Artifact artifact,
Block block,
Map<String, Object> arguments,
String engine) {
return new FtModel(name);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,59 +30,21 @@
import ai.djl.repository.zoo.Criteria;
import ai.djl.repository.zoo.ModelNotFoundException;
import ai.djl.repository.zoo.ZooModel;
import ai.djl.testing.Assertions;
import ai.djl.training.ParameterStore;
import ai.djl.training.util.ProgressBar;
import ai.djl.translate.NoopTranslator;
import ai.djl.translate.TranslateException;
import ai.djl.util.Utils;
import ai.djl.util.ZipUtils;
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.Map;
import org.testng.Assert;
import org.testng.annotations.Test;

public class BlockFactoryTest {

@Test
public void testBlockLoadingSaving()
throws IOException, ModelNotFoundException, MalformedModelException,
TranslateException {
TestBlockFactory factory = new TestBlockFactory();
Model model = factory.getRemoveLastBlockModel();
try (NDManager manager = NDManager.newBaseManager()) {
Block block = model.getBlock();
block.forward(
new ParameterStore(manager, true),
new NDList(manager.ones(new Shape(1, 3, 32, 32))),
true);
ByteArrayOutputStream os = new ByteArrayOutputStream();
block.saveParameters(new DataOutputStream(os));
ByteArrayInputStream bis = new ByteArrayInputStream(os.toByteArray());
Block newBlock = factory.newBlock(manager);
newBlock.loadParameters(manager, new DataInputStream(bis));
try (Model test = Model.newInstance("test")) {
test.setBlock(newBlock);
try (Predictor<NDList, NDList> predOrigin =
model.newPredictor(new NoopTranslator());
Predictor<NDList, NDList> predDest =
test.newPredictor(new NoopTranslator())) {
NDList input = new NDList(manager.ones(new Shape(1, 3, 32, 32)));
NDList originOut = predOrigin.predict(input);
NDList destOut = predDest.predict(input);
Assertions.assertAlmostEquals(originOut, destOut);
}
}
}
model.close();
}

@Test
public void testBlockFactoryLoadingFromZip()
throws MalformedModelException, ModelNotFoundException, IOException,
Expand All @@ -97,9 +59,9 @@ public void testBlockFactoryLoadingFromZip()
.optModelPath(zipPath)
.optModelName("exported")
.build();
try (NDManager manager = NDManager.newBaseManager();
ZooModel<NDList, NDList> model = criteria.loadModel();
try (ZooModel<NDList, NDList> model = criteria.loadModel();
Predictor<NDList, NDList> pred = model.newPredictor()) {
NDManager manager = model.getNDManager();
NDList destOut = pred.predict(new NDList(manager.ones(new Shape(1, 3, 32, 32))));
Assert.assertEquals(destOut.singletonOrThrow().getShape(), new Shape(1, 10));
}
Expand Down Expand Up @@ -136,9 +98,9 @@ public static class TestBlockFactory implements BlockFactory {
private static final long serialVersionUID = 1234567L;

@Override
public Block newBlock(NDManager manager) {
public Block newBlock(Model model, Path modelPath, Map<String, ?> arguments) {
SequentialBlock newBlock = new SequentialBlock();
newBlock.add(SymbolBlock.newInstance(manager));
newBlock.add(SymbolBlock.newInstance(model.getNDManager()));
newBlock.add(Linear.builder().setUnits(10).build());
return newBlock;
}
Expand Down
16 changes: 10 additions & 6 deletions model-zoo/src/main/java/ai/djl/basicmodelzoo/BasicModelZoo.java
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,10 @@
*/
package ai.djl.basicmodelzoo;

import ai.djl.basicmodelzoo.cv.classification.MlpModelLoader;
import ai.djl.basicmodelzoo.cv.classification.ResNetModelLoader;
import ai.djl.basicmodelzoo.cv.object_detection.ssd.SsdModelLoader;
import ai.djl.modality.cv.zoo.ImageClassificationModelLoader;
import ai.djl.modality.cv.zoo.ObjectDetectionModelLoader;
import ai.djl.repository.Repository;
import ai.djl.repository.zoo.ModelLoader;
import ai.djl.repository.zoo.ModelZoo;
import java.util.HashSet;
import java.util.Set;
Expand All @@ -25,11 +25,15 @@ public class BasicModelZoo implements ModelZoo {

private static final String REPO_URL = "https://mlrepo.djl.ai/";
private static final Repository REPOSITORY = Repository.newInstance("zoo", REPO_URL);
private static final ModelZoo ZOO = new BasicModelZoo();
public static final String GROUP_ID = "ai.djl.zoo";

public static final ResNetModelLoader RESNET = new ResNetModelLoader(REPOSITORY);
public static final MlpModelLoader MLP = new MlpModelLoader(REPOSITORY);
public static final SsdModelLoader SSD = new SsdModelLoader(REPOSITORY);
public static final ModelLoader RESNET =
new ImageClassificationModelLoader(REPOSITORY, GROUP_ID, "resnet", "0.0.2", ZOO);
public static final ModelLoader MLP =
new ImageClassificationModelLoader(REPOSITORY, GROUP_ID, "mlp", "0.0.3", ZOO);
public static final ModelLoader SSD =
new ObjectDetectionModelLoader(REPOSITORY, GROUP_ID, "ssd", "0.0.2", ZOO);

/** {@inheritDoc} */
@Override
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
/*
* Copyright 2021 Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
* with the License. A copy of the License is located at
*
* http://aws.amazon.com/apache2.0/
*
* or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
* OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
* and limitations under the License.
*/
package ai.djl.basicmodelzoo.basic;

import ai.djl.Model;
import ai.djl.nn.Block;
import ai.djl.nn.BlockFactory;
import java.nio.file.Path;
import java.util.List;
import java.util.Map;

/** A {@link BlockFactory} class that creates MLP block. */
public class MlpBlockFactory implements BlockFactory {

private static final long serialVersionUID = 1L;

/** {@inheritDoc} */
@Override
@SuppressWarnings("unchecked")
public Block newBlock(Model model, Path modelPath, Map<String, ?> arguments) {
Double width = (Double) arguments.get("width");
if (width == null) {
width = 28d;
}
Double height = (Double) arguments.get("height");
if (height == null) {
height = 28d;
}
int input = width.intValue() * height.intValue();
int output = ((Double) arguments.get("output")).intValue();
int[] hidden =
((List<Double>) arguments.get("hidden"))
.stream()
.mapToInt(Double::intValue)
.toArray();

return new Mlp(input, output, hidden);
}
}

This file was deleted.

Loading