diff --git a/api/src/main/java/ai/djl/BaseModel.java b/api/src/main/java/ai/djl/BaseModel.java index 4b2478d9a80..a41246ab527 100644 --- a/api/src/main/java/ai/djl/BaseModel.java +++ b/api/src/main/java/ai/djl/BaseModel.java @@ -19,13 +19,11 @@ import ai.djl.ndarray.types.DataType; import ai.djl.ndarray.types.Shape; import ai.djl.nn.Block; -import ai.djl.nn.BlockFactory; import ai.djl.nn.SymbolBlock; import ai.djl.training.ParameterStore; import ai.djl.training.Trainer; import ai.djl.training.TrainingConfig; import ai.djl.translate.Translator; -import ai.djl.util.ClassLoaderUtils; import ai.djl.util.Pair; import ai.djl.util.PairList; import ai.djl.util.Utils; @@ -217,14 +215,6 @@ protected void setModelDir(Path modelDir) { this.modelDir = modelDir.toAbsolutePath(); } - protected Block loadFromBlockFactory() { - BlockFactory factory = ClassLoaderUtils.findImplementation(modelDir, null); - if (factory == null) { - return null; - } - return factory.newBlock(manager); - } - /** {@inheritDoc} */ @Override public void save(Path modelPath, String newModelName) throws IOException { diff --git a/api/src/main/java/ai/djl/nn/BlockFactory.java b/api/src/main/java/ai/djl/nn/BlockFactory.java index c6747b0fe64..1e315c5ecca 100644 --- a/api/src/main/java/ai/djl/nn/BlockFactory.java +++ b/api/src/main/java/ai/djl/nn/BlockFactory.java @@ -12,9 +12,12 @@ */ package ai.djl.nn; -import ai.djl.ndarray.NDManager; +import ai.djl.Model; import ai.djl.repository.zoo.ModelZoo; +import java.io.IOException; import java.io.Serializable; +import java.nio.file.Path; +import java.util.Map; /** * Block factory is a component to make standard for block creating and saving procedure. Block @@ -27,8 +30,11 @@ public interface BlockFactory extends Serializable { /** * Constructs the uninitialized block. * - * @param manager the manager to assign to block + * @param model the model of the block + * @param modelPath the directory of the model location + * @param arguments the block creation arguments * @return the uninitialized block + * @throws IOException if IO operation fails during creating block */ - Block newBlock(NDManager manager); + Block newBlock(Model model, Path modelPath, Map arguments) throws IOException; } diff --git a/api/src/main/java/ai/djl/repository/zoo/BaseModelLoader.java b/api/src/main/java/ai/djl/repository/zoo/BaseModelLoader.java index e2ab30d7417..b692715d1eb 100644 --- a/api/src/main/java/ai/djl/repository/zoo/BaseModelLoader.java +++ b/api/src/main/java/ai/djl/repository/zoo/BaseModelLoader.java @@ -23,6 +23,8 @@ import ai.djl.modality.cv.Image; import ai.djl.modality.cv.translator.ImageClassificationTranslatorFactory; import ai.djl.ndarray.NDList; +import ai.djl.nn.Block; +import ai.djl.nn.BlockFactory; import ai.djl.repository.Artifact; import ai.djl.repository.MRL; import ai.djl.repository.Repository; @@ -32,6 +34,7 @@ import ai.djl.translate.TranslateException; import ai.djl.translate.Translator; import ai.djl.translate.TranslatorFactory; +import ai.djl.util.ClassLoaderUtils; import ai.djl.util.Pair; import ai.djl.util.Progress; import java.io.IOException; @@ -155,10 +158,14 @@ public ZooModel loadModel(Criteria criteria) modelName = artifact.getName(); } - Model model = createModel(modelName, criteria.getDevice(), artifact, arguments, engine); - if (criteria.getBlock() != null) { - model.setBlock(criteria.getBlock()); - } + Model model = + createModel( + modelPath, + modelName, + criteria.getDevice(), + criteria.getBlock(), + arguments, + engine); model.load(modelPath, null, options); Translator translator = factory.newInstance(model, arguments); return new ZooModel<>(model, translator); @@ -182,13 +189,25 @@ public List listModels() throws IOException { } protected Model createModel( + Path modelPath, String name, Device device, - Artifact artifact, + Block block, Map arguments, String engine) throws IOException { - return Model.newInstance(name, device, engine); + Model model = Model.newInstance(name, device, engine); + if (block == null) { + String className = (String) arguments.get("blockFactory"); + BlockFactory factory = ClassLoaderUtils.findImplementation(modelPath, className); + if (factory != null) { + block = factory.newBlock(model, modelPath, arguments); + } + } + if (block != null) { + model.setBlock(block); + } + return model; } /** {@inheritDoc} */ diff --git a/extensions/fasttext/src/main/java/ai/djl/fasttext/zoo/nlp/textclassification/TextClassificationModelLoader.java b/extensions/fasttext/src/main/java/ai/djl/fasttext/zoo/nlp/textclassification/TextClassificationModelLoader.java index 19a6bb1c957..e45010daa21 100644 --- a/extensions/fasttext/src/main/java/ai/djl/fasttext/zoo/nlp/textclassification/TextClassificationModelLoader.java +++ b/extensions/fasttext/src/main/java/ai/djl/fasttext/zoo/nlp/textclassification/TextClassificationModelLoader.java @@ -19,7 +19,7 @@ import ai.djl.fasttext.FtModel; import ai.djl.fasttext.zoo.FtModelZoo; import ai.djl.modality.Classifications; -import ai.djl.repository.Artifact; +import ai.djl.nn.Block; import ai.djl.repository.MRL; import ai.djl.repository.Repository; import ai.djl.repository.zoo.BaseModelLoader; @@ -30,6 +30,7 @@ import ai.djl.translate.TranslatorFactory; import ai.djl.util.Pair; import java.io.IOException; +import java.nio.file.Path; import java.util.Map; /** Model loader for fastText cooking stackexchange models. */ @@ -68,9 +69,10 @@ public ZooModel loadModel() /** {@inheritDoc} */ @Override protected Model createModel( + Path modelPath, String name, Device device, - Artifact artifact, + Block block, Map arguments, String engine) { return new FtModel(name); diff --git a/integration/src/main/java/ai/djl/integration/tests/nn/BlockFactoryTest.java b/integration/src/main/java/ai/djl/integration/tests/nn/BlockFactoryTest.java index f25d9cdf37f..bf3e2bcf1ad 100644 --- a/integration/src/main/java/ai/djl/integration/tests/nn/BlockFactoryTest.java +++ b/integration/src/main/java/ai/djl/integration/tests/nn/BlockFactoryTest.java @@ -30,59 +30,21 @@ import ai.djl.repository.zoo.Criteria; import ai.djl.repository.zoo.ModelNotFoundException; import ai.djl.repository.zoo.ZooModel; -import ai.djl.testing.Assertions; import ai.djl.training.ParameterStore; import ai.djl.training.util.ProgressBar; -import ai.djl.translate.NoopTranslator; import ai.djl.translate.TranslateException; import ai.djl.util.Utils; import ai.djl.util.ZipUtils; -import java.io.ByteArrayInputStream; -import java.io.ByteArrayOutputStream; -import java.io.DataInputStream; -import java.io.DataOutputStream; import java.io.IOException; import java.nio.file.Files; import java.nio.file.Path; import java.nio.file.Paths; +import java.util.Map; import org.testng.Assert; import org.testng.annotations.Test; public class BlockFactoryTest { - @Test - public void testBlockLoadingSaving() - throws IOException, ModelNotFoundException, MalformedModelException, - TranslateException { - TestBlockFactory factory = new TestBlockFactory(); - Model model = factory.getRemoveLastBlockModel(); - try (NDManager manager = NDManager.newBaseManager()) { - Block block = model.getBlock(); - block.forward( - new ParameterStore(manager, true), - new NDList(manager.ones(new Shape(1, 3, 32, 32))), - true); - ByteArrayOutputStream os = new ByteArrayOutputStream(); - block.saveParameters(new DataOutputStream(os)); - ByteArrayInputStream bis = new ByteArrayInputStream(os.toByteArray()); - Block newBlock = factory.newBlock(manager); - newBlock.loadParameters(manager, new DataInputStream(bis)); - try (Model test = Model.newInstance("test")) { - test.setBlock(newBlock); - try (Predictor predOrigin = - model.newPredictor(new NoopTranslator()); - Predictor predDest = - test.newPredictor(new NoopTranslator())) { - NDList input = new NDList(manager.ones(new Shape(1, 3, 32, 32))); - NDList originOut = predOrigin.predict(input); - NDList destOut = predDest.predict(input); - Assertions.assertAlmostEquals(originOut, destOut); - } - } - } - model.close(); - } - @Test public void testBlockFactoryLoadingFromZip() throws MalformedModelException, ModelNotFoundException, IOException, @@ -97,9 +59,9 @@ public void testBlockFactoryLoadingFromZip() .optModelPath(zipPath) .optModelName("exported") .build(); - try (NDManager manager = NDManager.newBaseManager(); - ZooModel model = criteria.loadModel(); + try (ZooModel model = criteria.loadModel(); Predictor pred = model.newPredictor()) { + NDManager manager = model.getNDManager(); NDList destOut = pred.predict(new NDList(manager.ones(new Shape(1, 3, 32, 32)))); Assert.assertEquals(destOut.singletonOrThrow().getShape(), new Shape(1, 10)); } @@ -136,9 +98,9 @@ public static class TestBlockFactory implements BlockFactory { private static final long serialVersionUID = 1234567L; @Override - public Block newBlock(NDManager manager) { + public Block newBlock(Model model, Path modelPath, Map arguments) { SequentialBlock newBlock = new SequentialBlock(); - newBlock.add(SymbolBlock.newInstance(manager)); + newBlock.add(SymbolBlock.newInstance(model.getNDManager())); newBlock.add(Linear.builder().setUnits(10).build()); return newBlock; } diff --git a/model-zoo/src/main/java/ai/djl/basicmodelzoo/BasicModelZoo.java b/model-zoo/src/main/java/ai/djl/basicmodelzoo/BasicModelZoo.java index 7a34564ce9c..586f0a40fc2 100644 --- a/model-zoo/src/main/java/ai/djl/basicmodelzoo/BasicModelZoo.java +++ b/model-zoo/src/main/java/ai/djl/basicmodelzoo/BasicModelZoo.java @@ -12,10 +12,10 @@ */ package ai.djl.basicmodelzoo; -import ai.djl.basicmodelzoo.cv.classification.MlpModelLoader; -import ai.djl.basicmodelzoo.cv.classification.ResNetModelLoader; -import ai.djl.basicmodelzoo.cv.object_detection.ssd.SsdModelLoader; +import ai.djl.modality.cv.zoo.ImageClassificationModelLoader; +import ai.djl.modality.cv.zoo.ObjectDetectionModelLoader; import ai.djl.repository.Repository; +import ai.djl.repository.zoo.ModelLoader; import ai.djl.repository.zoo.ModelZoo; import java.util.HashSet; import java.util.Set; @@ -25,11 +25,15 @@ public class BasicModelZoo implements ModelZoo { private static final String REPO_URL = "https://mlrepo.djl.ai/"; private static final Repository REPOSITORY = Repository.newInstance("zoo", REPO_URL); + private static final ModelZoo ZOO = new BasicModelZoo(); public static final String GROUP_ID = "ai.djl.zoo"; - public static final ResNetModelLoader RESNET = new ResNetModelLoader(REPOSITORY); - public static final MlpModelLoader MLP = new MlpModelLoader(REPOSITORY); - public static final SsdModelLoader SSD = new SsdModelLoader(REPOSITORY); + public static final ModelLoader RESNET = + new ImageClassificationModelLoader(REPOSITORY, GROUP_ID, "resnet", "0.0.2", ZOO); + public static final ModelLoader MLP = + new ImageClassificationModelLoader(REPOSITORY, GROUP_ID, "mlp", "0.0.3", ZOO); + public static final ModelLoader SSD = + new ObjectDetectionModelLoader(REPOSITORY, GROUP_ID, "ssd", "0.0.2", ZOO); /** {@inheritDoc} */ @Override diff --git a/model-zoo/src/main/java/ai/djl/basicmodelzoo/basic/MlpBlockFactory.java b/model-zoo/src/main/java/ai/djl/basicmodelzoo/basic/MlpBlockFactory.java new file mode 100644 index 00000000000..ff550de2bb5 --- /dev/null +++ b/model-zoo/src/main/java/ai/djl/basicmodelzoo/basic/MlpBlockFactory.java @@ -0,0 +1,49 @@ +/* + * Copyright 2021 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ +package ai.djl.basicmodelzoo.basic; + +import ai.djl.Model; +import ai.djl.nn.Block; +import ai.djl.nn.BlockFactory; +import java.nio.file.Path; +import java.util.List; +import java.util.Map; + +/** A {@link BlockFactory} class that creates MLP block. */ +public class MlpBlockFactory implements BlockFactory { + + private static final long serialVersionUID = 1L; + + /** {@inheritDoc} */ + @Override + @SuppressWarnings("unchecked") + public Block newBlock(Model model, Path modelPath, Map arguments) { + Double width = (Double) arguments.get("width"); + if (width == null) { + width = 28d; + } + Double height = (Double) arguments.get("height"); + if (height == null) { + height = 28d; + } + int input = width.intValue() * height.intValue(); + int output = ((Double) arguments.get("output")).intValue(); + int[] hidden = + ((List) arguments.get("hidden")) + .stream() + .mapToInt(Double::intValue) + .toArray(); + + return new Mlp(input, output, hidden); + } +} diff --git a/model-zoo/src/main/java/ai/djl/basicmodelzoo/cv/classification/MlpModelLoader.java b/model-zoo/src/main/java/ai/djl/basicmodelzoo/cv/classification/MlpModelLoader.java deleted file mode 100644 index 3e33fb18ed3..00000000000 --- a/model-zoo/src/main/java/ai/djl/basicmodelzoo/cv/classification/MlpModelLoader.java +++ /dev/null @@ -1,64 +0,0 @@ -/* - * Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance - * with the License. A copy of the License is located at - * - * http://aws.amazon.com/apache2.0/ - * - * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES - * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions - * and limitations under the License. - */ -package ai.djl.basicmodelzoo.cv.classification; - -import ai.djl.Device; -import ai.djl.Model; -import ai.djl.basicmodelzoo.BasicModelZoo; -import ai.djl.basicmodelzoo.basic.Mlp; -import ai.djl.modality.cv.zoo.ImageClassificationModelLoader; -import ai.djl.repository.Artifact; -import ai.djl.repository.Repository; -import java.util.List; -import java.util.Map; - -/** Model loader for MLP models. */ -public class MlpModelLoader extends ImageClassificationModelLoader { - - private static final String GROUP_ID = BasicModelZoo.GROUP_ID; - private static final String ARTIFACT_ID = "mlp"; - private static final String VERSION = "0.0.3"; - - /** - * Creates the Model loader from the given repository. - * - * @param repository the repository to load the model from - */ - public MlpModelLoader(Repository repository) { - super(repository, GROUP_ID, ARTIFACT_ID, VERSION, new BasicModelZoo()); - } - - /** {@inheritDoc} */ - @Override - protected Model createModel( - String name, - Device device, - Artifact artifact, - Map arguments, - String engine) { - int width = ((Double) arguments.getOrDefault("width", 28d)).intValue(); - int height = ((Double) arguments.getOrDefault("height", 28d)).intValue(); - int input = width * height; - int output = ((Double) arguments.get("output")).intValue(); - @SuppressWarnings("unchecked") - int[] hidden = - ((List) arguments.get("hidden")) - .stream() - .mapToInt(Double::intValue) - .toArray(); - - Model model = Model.newInstance(name, device, engine); - model.setBlock(new Mlp(input, output, hidden)); - return model; - } -} diff --git a/model-zoo/src/main/java/ai/djl/basicmodelzoo/cv/classification/ResNetModelLoader.java b/model-zoo/src/main/java/ai/djl/basicmodelzoo/cv/classification/ResNetModelLoader.java deleted file mode 100644 index 0cd99c6a09b..00000000000 --- a/model-zoo/src/main/java/ai/djl/basicmodelzoo/cv/classification/ResNetModelLoader.java +++ /dev/null @@ -1,75 +0,0 @@ -/* - * Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance - * with the License. A copy of the License is located at - * - * http://aws.amazon.com/apache2.0/ - * - * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES - * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions - * and limitations under the License. - */ -package ai.djl.basicmodelzoo.cv.classification; - -import ai.djl.Device; -import ai.djl.Model; -import ai.djl.basicmodelzoo.BasicModelZoo; -import ai.djl.basicmodelzoo.cv.classification.ResNetV1.Builder; -import ai.djl.modality.cv.zoo.ImageClassificationModelLoader; -import ai.djl.ndarray.types.Shape; -import ai.djl.nn.Block; -import ai.djl.repository.Artifact; -import ai.djl.repository.Repository; -import java.util.List; -import java.util.Map; - -/** Model loader for ResNet_V1. */ -public class ResNetModelLoader extends ImageClassificationModelLoader { - - private static final String GROUP_ID = BasicModelZoo.GROUP_ID; - private static final String ARTIFACT_ID = "resnet"; - private static final String VERSION = "0.0.2"; - - /** - * Creates the Model loader from the given repository. - * - * @param repository the repository to load the model from - */ - public ResNetModelLoader(Repository repository) { - super(repository, GROUP_ID, ARTIFACT_ID, VERSION, new BasicModelZoo()); - } - - /** {@inheritDoc} */ - @Override - protected Model createModel( - String name, - Device device, - Artifact artifact, - Map arguments, - String engine) { - Model model = Model.newInstance(name, device, engine); - model.setBlock(resnetBlock(arguments)); - return model; - } - - private Block resnetBlock(Map arguments) { - @SuppressWarnings("unchecked") - Shape shape = - new Shape( - ((List) arguments.get("imageShape")) - .stream() - .mapToLong(Double::longValue) - .toArray()); - Builder blockBuilder = - ResNetV1.builder() - .setNumLayers((int) ((double) arguments.get("numLayers"))) - .setOutSize((long) ((double) arguments.get("outSize"))) - .setImageShape(shape); - if (arguments.containsKey("batchNormMomentum")) { - float batchNormMomentum = (float) ((double) arguments.get("batchNormMomentum")); - blockBuilder.optBatchNormMomentum(batchNormMomentum); - } - return blockBuilder.build(); - } -} diff --git a/model-zoo/src/main/java/ai/djl/basicmodelzoo/cv/classification/ResnetBlockFactory.java b/model-zoo/src/main/java/ai/djl/basicmodelzoo/cv/classification/ResnetBlockFactory.java new file mode 100644 index 00000000000..718c6ce3113 --- /dev/null +++ b/model-zoo/src/main/java/ai/djl/basicmodelzoo/cv/classification/ResnetBlockFactory.java @@ -0,0 +1,49 @@ +/* + * Copyright 2021 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ +package ai.djl.basicmodelzoo.cv.classification; + +import ai.djl.Model; +import ai.djl.ndarray.types.Shape; +import ai.djl.nn.Block; +import ai.djl.nn.BlockFactory; +import java.nio.file.Path; +import java.util.List; +import java.util.Map; + +/** A {@link BlockFactory} class that creates {@link ResNetV1} block. */ +public class ResnetBlockFactory implements BlockFactory { + + private static final long serialVersionUID = 1L; + + /** {@inheritDoc} */ + @Override + public Block newBlock(Model model, Path modelPath, Map arguments) { + @SuppressWarnings("unchecked") + Shape shape = + new Shape( + ((List) arguments.get("imageShape")) + .stream() + .mapToLong(Double::longValue) + .toArray()); + ResNetV1.Builder blockBuilder = + ResNetV1.builder() + .setNumLayers(((Double) arguments.get("numLayers")).intValue()) + .setOutSize(((Double) arguments.get("outSize")).longValue()) + .setImageShape(shape); + if (arguments.containsKey("batchNormMomentum")) { + float batchNormMomentum = ((Double) arguments.get("batchNormMomentum")).floatValue(); + blockBuilder.optBatchNormMomentum(batchNormMomentum); + } + return blockBuilder.build(); + } +} diff --git a/model-zoo/src/main/java/ai/djl/basicmodelzoo/cv/object_detection/ssd/SsdModelLoader.java b/model-zoo/src/main/java/ai/djl/basicmodelzoo/cv/object_detection/ssd/SsdBlockFactory.java similarity index 66% rename from model-zoo/src/main/java/ai/djl/basicmodelzoo/cv/object_detection/ssd/SsdModelLoader.java rename to model-zoo/src/main/java/ai/djl/basicmodelzoo/cv/object_detection/ssd/SsdBlockFactory.java index a75187e4312..1a3a97ec18c 100644 --- a/model-zoo/src/main/java/ai/djl/basicmodelzoo/cv/object_detection/ssd/SsdModelLoader.java +++ b/model-zoo/src/main/java/ai/djl/basicmodelzoo/cv/object_detection/ssd/SsdBlockFactory.java @@ -1,5 +1,5 @@ /* - * Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * Copyright 2021 Amazon.com, Inc. or its affiliates. All Rights Reserved. * * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance * with the License. A copy of the License is located at @@ -10,56 +10,30 @@ * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions * and limitations under the License. */ - package ai.djl.basicmodelzoo.cv.object_detection.ssd; -import ai.djl.Device; import ai.djl.Model; -import ai.djl.basicmodelzoo.BasicModelZoo; -import ai.djl.modality.cv.zoo.ObjectDetectionModelLoader; import ai.djl.nn.Block; +import ai.djl.nn.BlockFactory; import ai.djl.nn.SequentialBlock; -import ai.djl.repository.Artifact; -import ai.djl.repository.Repository; +import java.nio.file.Path; import java.util.ArrayList; import java.util.List; import java.util.Map; import java.util.stream.Collectors; -/** Model loader for SingleShotDetection(SSD). */ -public class SsdModelLoader extends ObjectDetectionModelLoader { - - private static final String GROUP_ID = BasicModelZoo.GROUP_ID; - private static final String ARTIFACT_ID = "ssd"; - private static final String VERSION = "0.0.2"; +/** A {@link BlockFactory} class that creates {@link SingleShotDetection} block. */ +public class SsdBlockFactory implements BlockFactory { - /** - * Creates the Model loader from the given repository. - * - * @param repository the repository to load the model from - */ - public SsdModelLoader(Repository repository) { - super(repository, GROUP_ID, ARTIFACT_ID, VERSION, new BasicModelZoo()); - } + private static final long serialVersionUID = 1L; /** {@inheritDoc} */ @Override - protected Model createModel( - String name, - Device device, - Artifact artifact, - Map arguments, - String engine) { - Model model = Model.newInstance(name, device, engine); - model.setBlock(customSSDBlock(arguments)); - return model; - } - @SuppressWarnings("unchecked") - private Block customSSDBlock(Map arguments) { + public Block newBlock(Model model, Path modelPath, Map arguments) { int numClasses = ((Double) arguments.get("outSize")).intValue(); int numFeatures = ((Double) arguments.get("numFeatures")).intValue(); - boolean globalPool = (boolean) arguments.get("globalPool"); + boolean globalPool = (Boolean) arguments.get("globalPool"); int[] numFilters = ((List) arguments.get("numFilters")) .stream() diff --git a/model-zoo/src/test/resources/mlrepo/model/cv/image_classification/ai/djl/zoo/mlp/metadata.json b/model-zoo/src/test/resources/mlrepo/model/cv/image_classification/ai/djl/zoo/mlp/metadata.json index 665dadbcbfe..d313073c39a 100644 --- a/model-zoo/src/test/resources/mlrepo/model/cv/image_classification/ai/djl/zoo/mlp/metadata.json +++ b/model-zoo/src/test/resources/mlrepo/model/cv/image_classification/ai/djl/zoo/mlp/metadata.json @@ -57,6 +57,7 @@ "width": 28, "height": 28, "output": 10, + "blockFactory": "ai.djl.basicmodelzoo.basic.MlpBlockFactory", "hidden": [ 128, 64 diff --git a/model-zoo/src/test/resources/mlrepo/model/cv/image_classification/ai/djl/zoo/resnet/metadata.json b/model-zoo/src/test/resources/mlrepo/model/cv/image_classification/ai/djl/zoo/resnet/metadata.json index ef96a95de5b..cb6f1f03b1c 100644 --- a/model-zoo/src/test/resources/mlrepo/model/cv/image_classification/ai/djl/zoo/resnet/metadata.json +++ b/model-zoo/src/test/resources/mlrepo/model/cv/image_classification/ai/djl/zoo/resnet/metadata.json @@ -30,6 +30,7 @@ "resize": true, "numLayers": 50, "outSize": 10, + "blockFactory": "ai.djl.basicmodelzoo.cv.classification.ResnetBlockFactory", "imageShape": [ 3, 32, @@ -66,6 +67,7 @@ "resize": true, "numLayers": 50, "outSize": 10, + "blockFactory": "ai.djl.basicmodelzoo.cv.classification.ResnetBlockFactory", "imageShape": [ 3, 32, diff --git a/model-zoo/src/test/resources/mlrepo/model/cv/object_detection/ai/djl/zoo/ssd/metadata.json b/model-zoo/src/test/resources/mlrepo/model/cv/object_detection/ai/djl/zoo/ssd/metadata.json index 8a0bd7bdb79..44dbc5df780 100644 --- a/model-zoo/src/test/resources/mlrepo/model/cv/object_detection/ai/djl/zoo/ssd/metadata.json +++ b/model-zoo/src/test/resources/mlrepo/model/cv/object_detection/ai/djl/zoo/ssd/metadata.json @@ -88,6 +88,7 @@ "width": 256, "height": 256, "numFeatures": 3, + "blockFactory": "ai.djl.basicmodelzoo.cv.object_detection.ssd.SsdBlockFactory", "numFilters": [ 16, 32, diff --git a/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxModel.java b/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxModel.java index 839bf8703aa..5c675a9519c 100644 --- a/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxModel.java +++ b/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxModel.java @@ -100,10 +100,6 @@ public void load(Path modelPath, String prefix, Map options) } } - if (block == null) { - block = loadFromBlockFactory(); - } - if (block == null) { // load MxSymbolBlock Path symbolFile = modelDir.resolve(prefix + "-symbol.json"); diff --git a/mxnet/mxnet-model-zoo/src/main/java/ai/djl/mxnet/zoo/nlp/embedding/GloveWordEmbeddingBlockFactory.java b/mxnet/mxnet-model-zoo/src/main/java/ai/djl/mxnet/zoo/nlp/embedding/GloveWordEmbeddingBlockFactory.java new file mode 100644 index 00000000000..3ef95acc209 --- /dev/null +++ b/mxnet/mxnet-model-zoo/src/main/java/ai/djl/mxnet/zoo/nlp/embedding/GloveWordEmbeddingBlockFactory.java @@ -0,0 +1,47 @@ +/* + * Copyright 2021 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ +package ai.djl.mxnet.zoo.nlp.embedding; + +import ai.djl.Model; +import ai.djl.modality.nlp.SimpleVocabulary; +import ai.djl.modality.nlp.embedding.TrainableWordEmbedding; +import ai.djl.nn.Block; +import ai.djl.nn.BlockFactory; +import ai.djl.util.Utils; +import java.io.IOException; +import java.nio.file.Path; +import java.util.List; +import java.util.Map; + +/** A {@link BlockFactory} class that creates Glove word embedding block. */ +public class GloveWordEmbeddingBlockFactory implements BlockFactory { + + private static final long serialVersionUID = 1L; + + /** {@inheritDoc} */ + @Override + public Block newBlock(Model model, Path modelPath, Map arguments) + throws IOException { + List idxToToken = Utils.readLines(modelPath.resolve("idx_to_token.txt")); + String dimension = (String) arguments.get("dimensions"); + TrainableWordEmbedding wordEmbedding = + TrainableWordEmbedding.builder() + .optNumEmbeddings(Integer.parseInt(dimension)) + .setVocabulary(new SimpleVocabulary(idxToToken)) + .optUnknownToken((String) arguments.get("unknownToken")) + .optUseDefault(true) + .build(); + model.setProperty("unknownToken", (String) arguments.get("unknownToken")); + return wordEmbedding; + } +} diff --git a/mxnet/mxnet-model-zoo/src/main/java/ai/djl/mxnet/zoo/nlp/embedding/GloveWordEmbeddingModelLoader.java b/mxnet/mxnet-model-zoo/src/main/java/ai/djl/mxnet/zoo/nlp/embedding/GloveWordEmbeddingModelLoader.java index 1f50f12b472..2cafdc4abb4 100644 --- a/mxnet/mxnet-model-zoo/src/main/java/ai/djl/mxnet/zoo/nlp/embedding/GloveWordEmbeddingModelLoader.java +++ b/mxnet/mxnet-model-zoo/src/main/java/ai/djl/mxnet/zoo/nlp/embedding/GloveWordEmbeddingModelLoader.java @@ -14,17 +14,13 @@ import ai.djl.Application; import ai.djl.Application.NLP; -import ai.djl.Device; import ai.djl.MalformedModelException; import ai.djl.Model; -import ai.djl.modality.nlp.SimpleVocabulary; -import ai.djl.modality.nlp.embedding.TrainableWordEmbedding; import ai.djl.modality.nlp.embedding.WordEmbedding; import ai.djl.mxnet.zoo.MxModelZoo; import ai.djl.ndarray.NDList; import ai.djl.ndarray.NDManager; import ai.djl.nn.core.Embedding; -import ai.djl.repository.Artifact; import ai.djl.repository.MRL; import ai.djl.repository.Repository; import ai.djl.repository.zoo.BaseModelLoader; @@ -36,9 +32,7 @@ import ai.djl.translate.TranslatorContext; import ai.djl.translate.TranslatorFactory; import ai.djl.util.Pair; -import ai.djl.util.Utils; import java.io.IOException; -import java.util.List; import java.util.Map; /** @@ -62,38 +56,6 @@ public GloveWordEmbeddingModelLoader(Repository repository) { factories.put(new Pair<>(String.class, NDList.class), new FactoryImpl()); } - private Model customGloveBlock(Model model, Artifact artifact, Map arguments) - throws IOException { - List idxToToken = - Utils.readLines( - resource.getRepository() - .openStream(artifact.getFiles().get("idx_to_token"), null)); - TrainableWordEmbedding wordEmbedding = - TrainableWordEmbedding.builder() - .optNumEmbeddings( - Integer.parseInt(artifact.getProperties().get("dimensions"))) - .setVocabulary(new SimpleVocabulary(idxToToken)) - .optUnknownToken((String) arguments.get("unknownToken")) - .optUseDefault(true) - .build(); - model.setBlock(wordEmbedding); - model.setProperty("unknownToken", (String) arguments.get("unknownToken")); - return model; - } - - /** {@inheritDoc} */ - @Override - protected Model createModel( - String name, - Device device, - Artifact artifact, - Map arguments, - String engine) - throws IOException { - Model model = Model.newInstance(name, device, engine); - return customGloveBlock(model, artifact, arguments); - } - /** * Loads the model with the given search filters. * diff --git a/mxnet/mxnet-model-zoo/src/test/resources/mlrepo/model/nlp/word_embedding/ai/djl/mxnet/glove/metadata.json b/mxnet/mxnet-model-zoo/src/test/resources/mlrepo/model/nlp/word_embedding/ai/djl/mxnet/glove/metadata.json index 9debfa6b035..e856c127994 100644 --- a/mxnet/mxnet-model-zoo/src/test/resources/mlrepo/model/nlp/word_embedding/ai/djl/mxnet/glove/metadata.json +++ b/mxnet/mxnet-model-zoo/src/test/resources/mlrepo/model/nlp/word_embedding/ai/djl/mxnet/glove/metadata.json @@ -22,7 +22,9 @@ "dimensions": "50" }, "arguments": { - "unknownToken": "\u003cunk\u003e" + "dimensions": "50", + "unknownToken": "\u003cunk\u003e", + "blockFactory": "ai.djl.mxnet.zoo.nlp.embedding.GloveWordEmbeddingBlockFactory" }, "files": { "parameters": { @@ -45,7 +47,9 @@ "dimensions": "50" }, "arguments": { - "unknownToken": "\u003cunk\u003e" + "dimensions": "50", + "unknownToken": "\u003cunk\u003e", + "blockFactory": "ai.djl.mxnet.zoo.nlp.embedding.GloveWordEmbeddingBlockFactory" }, "files": { "parameters": { diff --git a/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtModel.java b/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtModel.java index 697607b0078..d31db77dbeb 100644 --- a/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtModel.java +++ b/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtModel.java @@ -65,10 +65,6 @@ public void load(Path modelPath, String prefix, Map options) prefix = modelName; } - if (block == null) { - block = loadFromBlockFactory(); - } - if (block == null) { Path modelFile = findModelFile(prefix); if (modelFile == null) {