diff --git a/api/src/main/java/ai/djl/ndarray/internal/NDArrayEx.java b/api/src/main/java/ai/djl/ndarray/internal/NDArrayEx.java
index 4388d68b765..4a405308b1f 100644
--- a/api/src/main/java/ai/djl/ndarray/internal/NDArrayEx.java
+++ b/api/src/main/java/ai/djl/ndarray/internal/NDArrayEx.java
@@ -296,6 +296,8 @@ NDList deconvolution(
 
     NDList dropout(NDArray input, float rate, boolean training);
 
+    NDList layerNorm(NDArray input, Shape normalizedShape, NDArray gamma, NDArray beta, float eps);
+
     NDList batchNorm(
             NDArray input,
             NDArray runningMean,
diff --git a/api/src/main/java/ai/djl/nn/norm/LayerNorm.java b/api/src/main/java/ai/djl/nn/norm/LayerNorm.java
new file mode 100644
index 00000000000..0861e1fa72f
--- /dev/null
+++ b/api/src/main/java/ai/djl/nn/norm/LayerNorm.java
@@ -0,0 +1,248 @@
+/*
+ * Copyright 2021 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
+ * with the License. A copy of the License is located at
+ *
+ * http://aws.amazon.com/apache2.0/
+ *
+ * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
+ * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
+ * and limitations under the License.
+ */
+package ai.djl.nn.norm;
+
+import ai.djl.Device;
+import ai.djl.MalformedModelException;
+import ai.djl.ndarray.NDArray;
+import ai.djl.ndarray.NDList;
+import ai.djl.ndarray.internal.NDArrayEx;
+import ai.djl.ndarray.types.Shape;
+import ai.djl.nn.AbstractBlock;
+import ai.djl.nn.Parameter;
+import ai.djl.training.ParameterStore;
+import ai.djl.util.PairList;
+import java.io.DataInputStream;
+import java.io.DataOutputStream;
+import java.io.IOException;
+import java.util.Arrays;
+
+/**
+ * Layer normalization works by normalizing the values of input data for each input sample to have
+ * a mean of 0 and a variance of 1. Since this may alter the representation of a layer, two
+ * parameters (\(\gamma\) and \(\beta\)) are learned along with the normalization process to
+ * respectively scale and shift the normalized output (activations) back to any mean and variance,
+ * so the network can still utilize non-linear transformations such as the sigmoid function, as
+ * described in the paper. During backpropagation, both the \(\gamma\) and \(\beta\) parameters are
+ * included in the chain-rule derivation.
+ *
+ * <p>Citing the abstract of the paper: "Training state-of-the-art, deep neural networks is
+ * computationally expensive. One way to reduce the training time is to normalize the activities of
+ * the neurons. A recently introduced technique called batch normalization uses the distribution of
+ * the summed input to a neuron over a mini-batch of training cases to compute a mean and variance
+ * which are then used to normalize the summed input to that neuron on each training case. This
+ * significantly reduces the training time in feed-forward neural networks. However, the effect of
+ * batch normalization is dependent on the mini-batch size and it is not obvious how to apply it to
+ * recurrent neural networks. In this paper, we transpose batch normalization into layer
+ * normalization by computing the mean and variance used for normalization from all of the summed
+ * inputs to the neurons in a layer on a single training case. Like batch normalization, we also
+ * give each neuron its own adaptive bias and gain which are applied after the normalization but
+ * before the non-linearity. Unlike batch normalization, layer normalization performs exactly the
+ * same computation at training and test times. It is also straightforward to apply to recurrent
+ * neural networks by computing the normalization statistics separately at each time step. Layer
+ * normalization is very effective at stabilizing the hidden state dynamics in recurrent networks.
+ * Empirically, we show that layer normalization can substantially reduce the training time
+ * compared with previously published techniques."
+ */
+public class LayerNorm extends AbstractBlock {
+
+    private static final byte VERSION = 1;
+
+    private float epsilon;
+    private Shape normalizedShape;
+
+    private boolean center;
+    private boolean scale;
+    private int[] axis;
+    private Parameter gamma;
+    private Parameter beta;
+
+    LayerNorm(Builder builder) {
+        super(VERSION);
+        epsilon = builder.epsilon;
+        scale = builder.scale;
+        center = builder.center;
+        axis = builder.axis;
+
+        // make gamma trainable if scale is true
+        gamma =
+                addParameter(
+                        Parameter.builder()
+                                .setName("gamma")
+                                .setType(Parameter.Type.GAMMA)
+                                .optRequiresGrad(scale)
+                                .build());
+        // make beta trainable if center is true
+        beta =
+                addParameter(
+                        Parameter.builder()
+                                .setName("beta")
+                                .setType(Parameter.Type.BETA)
+                                .optRequiresGrad(center)
+                                .build());
+    }
+
+    /**
+     * Applies Layer Normalization, computing the mean and variance for each input sample over the
+     * {@code normalizedShape} dimensions.
+     *
+     * @param input the input {@code NDArray} of shape (batchSize, inputChannel, *), * could be
+     *     empty, width, (height, width), (depth, height, width)
+     * @param normalizedShape the dimensions over which the mean and variance are calculated
+     * @param gamma gamma weight {@code NDArray}
+     * @param beta beta weight {@code NDArray}
+     * @param eps a value added to the denominator for numerical stability
+     * @return the output {@code NDArray} of shape (batchSize, inputChannel, *), * could be empty,
+     *     width, (height, width), (depth, height, width)
+     */
+    public static NDList layerNorm(
+            NDArray input, Shape normalizedShape, NDArray gamma, NDArray beta, float eps) {
+        NDArrayEx ex = input.getNDArrayInternal();
+        return ex.layerNorm(input, normalizedShape, gamma, beta, eps);
+    }
+
+    /**
+     * Creates a builder to build a {@code LayerNorm}.
+     *
+     * @return a new builder
+     */
+    public static Builder builder() {
+        return new Builder();
+    }
+
+    /** {@inheritDoc} */
+    @Override
+    protected NDList forwardInternal(
+            ParameterStore parameterStore,
+            NDList inputs,
+            boolean training,
+            PairList<String, Object> params) {
+        NDArray input = inputs.singletonOrThrow();
+        Device device = input.getDevice();
+        NDArray gammaArr = parameterStore.getValue(gamma, device, training);
+        NDArray betaArr = parameterStore.getValue(beta, device, training);
+
+        return layerNorm(input, normalizedShape, gammaArr, betaArr, epsilon);
+    }
+
+    /** {@inheritDoc} */
+    @Override
+    public Shape[] getOutputShapes(Shape[] inputShapes) {
+        return new Shape[] {inputShapes[0]};
+    }
+
+    /** {@inheritDoc} */
+    @Override
+    protected void beforeInitialize(Shape... inputShapes) {
+        super.beforeInitialize(inputShapes);
+        normalizedShape =
+                axis == null
+                        ? inputShapes[0].slice(1)
+                        : new Shape(
+                                Arrays.stream(axis)
+                                        .mapToLong(dim -> inputShapes[0].get(dim))
+                                        .toArray());
+    }
+
+    /** {@inheritDoc} */
+    @Override
+    public void prepare(Shape[] inputShapes) {
+        gamma.setShape(normalizedShape);
+        beta.setShape(normalizedShape);
+    }
+
+    /** {@inheritDoc} */
+    @Override
+    protected void saveMetadata(DataOutputStream os) throws IOException {
+        saveInputShapes(os);
+        os.write(normalizedShape.getEncoded());
+    }
+
+    /** {@inheritDoc} */
+    @Override
+    public void loadMetadata(byte version, DataInputStream is)
+            throws IOException, MalformedModelException {
+        if (version != VERSION) {
+            throw new MalformedModelException("Unsupported encoding version: " + version);
+        }
+        readInputShapes(is);
+        normalizedShape = Shape.decode(is);
+    }
+
+    /** The Builder to construct a {@link LayerNorm}. */
+    public static final class Builder {
+
+        private float epsilon = 1E-5f;
+        // private Shape normalizedShape;
+        private boolean scale = true;
+        private boolean center = true;
+        private int[] axis;
+
+        Builder() {}
+
+        /**
+         * Sets the axes over which the mean and variance will be calculated (an alternative to
+         * normalizedShape).
+         *
+         * @param axis the input axes over which the mean and variance will be calculated (if
+         *     null, all dimensions except the first (batch) dimension are used)
+         * @return this Builder
+         */
+        public Builder axis(int... axis) {
+            this.axis = axis;
+            return this;
+        }
+
+        /**
+         * If true, adds the offset {@code beta} to the normalized tensor. Defaults to true.
+         *
+         * @param val true or false on whether to add and train the offset value
+         * @return this Builder
+         */
+        public Builder optCenter(boolean val) {
+            center = val;
+            return this;
+        }
+
+        /**
+         * If true, multiplies the result by {@code gamma}. Defaults to true.
+         *
+         * @param val true or false on whether to add and train the scale value
+         * @return this Builder
+         */
+        public Builder optScale(boolean val) {
+            scale = val;
+            return this;
+        }
+
+        /**
+         * Sets the epsilon value to prevent division by 0.
+         *
+         * @param val the epsilon value
+         * @return this Builder
+         */
+        public Builder optEpsilon(float val) {
+            epsilon = val;
+            return this;
+        }
+
+        /**
+         * Builds a {@link LayerNorm} block.
+         *
+         * @return the {@link LayerNorm} block
+         */
+        public LayerNorm build() {
+            return new LayerNorm(this);
+        }
+    }
+}
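// Editor's note: an illustrative sketch, not part of the patch above. It shows how the new
// LayerNorm block might be wired into a model through the builder added in LayerNorm.java,
// mirroring the integration test that follows. The explicit "PyTorch" engine and the 2x2 input
// are assumptions for the example; only the PyTorch engine implements layerNorm in this change.
import ai.djl.Device;
import ai.djl.Model;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.Shape;
import ai.djl.nn.norm.LayerNorm;
import ai.djl.training.DefaultTrainingConfig;
import ai.djl.training.Trainer;
import ai.djl.training.loss.Loss;

public final class LayerNormUsageSketch {
    public static void main(String[] args) {
        try (Model model = Model.newInstance("layer-norm-example", Device.cpu(), "PyTorch")) {
            // Normalize over all non-batch dimensions; gamma and beta are learned per element.
            model.setBlock(LayerNorm.builder().optEpsilon(1e-5f).build());
            try (Trainer trainer = model.newTrainer(new DefaultTrainingConfig(Loss.l2Loss()))) {
                trainer.initialize(new Shape(2, 2));
                NDManager manager = trainer.getManager();
                NDArray data = manager.create(new float[] {1, 3, 2, 4}, new Shape(2, 2));
                // Each row is normalized to zero mean and unit variance: with the default
                // gamma = 1 and beta = 0, {1, 3} and {2, 4} both map to approximately {-1, 1}.
                NDArray output = trainer.forward(new NDList(data)).singletonOrThrow();
                System.out.println(output);
            }
        }
    }
}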
diff --git a/integration/src/main/java/ai/djl/integration/tests/nn/BlockCoreTest.java b/integration/src/main/java/ai/djl/integration/tests/nn/BlockCoreTest.java
index c133b382327..e577f2a3d8a 100644
--- a/integration/src/main/java/ai/djl/integration/tests/nn/BlockCoreTest.java
+++ b/integration/src/main/java/ai/djl/integration/tests/nn/BlockCoreTest.java
@@ -12,6 +12,7 @@
  */
 package ai.djl.integration.tests.nn;
 
+import ai.djl.Device;
 import ai.djl.MalformedModelException;
 import ai.djl.Model;
 import ai.djl.engine.Engine;
@@ -36,6 +37,7 @@
 import ai.djl.nn.core.Linear;
 import ai.djl.nn.norm.BatchNorm;
 import ai.djl.nn.norm.Dropout;
+import ai.djl.nn.norm.LayerNorm;
 import ai.djl.nn.recurrent.GRU;
 import ai.djl.nn.recurrent.LSTM;
 import ai.djl.nn.recurrent.RNN;
@@ -202,6 +204,60 @@ public void testBatchNorm() throws IOException, MalformedModelException {
         }
     }
 
+    @SuppressWarnings("try")
+    @Test
+    public void testLayerNorm() throws IOException, MalformedModelException {
+        TrainingConfig config =
+                new DefaultTrainingConfig(Loss.l2Loss())
+                        .optInitializer(Initializer.ONES, Parameter.Type.WEIGHT);
+
+        Block block = LayerNorm.builder().build();
+        try (Model model = Model.newInstance("model", Device.cpu(), "PyTorch")) {
+            model.setBlock(block);
+
+            try (Trainer trainer = model.newTrainer(config)) {
+                try (GradientCollector collector = trainer.newGradientCollector()) {
+                    Shape inputShape = new Shape(2, 2);
+                    trainer.initialize(inputShape);
+
+                    NDManager manager = trainer.getManager();
+                    NDArray data = manager.create(new float[] {1, 3, 2, 4}, inputShape);
+                    NDArray expected = manager.create(new float[] {-1, 1, -1, 1}, inputShape);
+                    NDArray result = trainer.forward(new NDList(data)).singletonOrThrow();
+                    Assertions.assertAlmostEquals(result, expected);
+                    testEncode(manager, block);
+                }
+            }
+        }
+    }
+
+    @SuppressWarnings("try")
+    @Test
+    public void test2LayerNorm() throws IOException, MalformedModelException {
+        TrainingConfig config =
+                new DefaultTrainingConfig(Loss.l2Loss())
+                        .optInitializer(Initializer.ONES, Parameter.Type.WEIGHT);
+
+        Block block = LayerNorm.builder().axis(2, 3).build();
+        try (Model model = Model.newInstance("model", Device.cpu(), "PyTorch")) {
+            model.setBlock(block);
+
+            try (Trainer trainer = model.newTrainer(config)) {
+                try (GradientCollector collector = trainer.newGradientCollector()) {
+                    Shape inputShape = new Shape(1, 2, 1, 2);
+                    trainer.initialize(inputShape);
+
+                    NDManager manager = trainer.getManager();
+                    NDArray data = manager.create(new float[] {1, 3, 2, 4}, inputShape);
+                    NDArray expected = manager.create(new float[] {-1, 1, -1, 1}, inputShape);
+                    NDArray result = trainer.forward(new NDList(data)).singletonOrThrow();
+                    Assertions.assertAlmostEquals(result, expected);
+                    testEncode(manager, block);
+                }
+            }
+        }
+    }
+
     @SuppressWarnings("try")
     @Test
     public void testDropout() throws IOException, MalformedModelException {
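// Editor's note: a small arithmetic check, not part of the patch, of the expected values used in
// both tests above. For each normalized slice ({1, 3} and {2, 4} per row in testLayerNorm, and
// per axis-(2, 3) slice in test2LayerNorm) the mean is removed and the result is divided by the
// standard deviation, which is why the expected output is approximately {-1, 1} everywhere.
public final class LayerNormExpectedValues {
    public static void main(String[] args) {
        float[] slice = {1f, 3f};
        float eps = 1e-5f;
        float mean = (slice[0] + slice[1]) / 2f; // 2.0
        float variance =
                ((slice[0] - mean) * (slice[0] - mean) + (slice[1] - mean) * (slice[1] - mean))
                        / 2f; // 1.0
        for (float x : slice) {
            // Prints roughly -1.0 and 1.0, matching the expected NDArray in the tests.
            System.out.println((x - mean) / Math.sqrt(variance + eps));
        }
    }
}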
diff --git a/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxNDArrayEx.java b/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxNDArrayEx.java
index 7437ba0d2ef..31feabb58e2 100644
--- a/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxNDArrayEx.java
+++ b/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxNDArrayEx.java
@@ -636,6 +636,13 @@ public NDList dropout(NDArray input, float rate, boolean training) {
         return getManager().invoke("_npx_dropout", new NDList(input), params);
     }
 
+    /** {@inheritDoc} */
+    @Override
+    public NDList layerNorm(
+            NDArray input, Shape normalizedShape, NDArray gamma, NDArray beta, float eps) {
+        throw new UnsupportedOperationException();
+    }
+
     /** {@inheritDoc} */
     @Override
     public NDList batchNorm(
diff --git a/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArrayEx.java b/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArrayEx.java
index 6c2dd0488ae..bce4a877240 100644
--- a/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArrayEx.java
+++ b/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArrayEx.java
@@ -411,6 +411,18 @@ public NDList dropout(NDArray input, float rate, boolean training) {
         return new NDList(JniUtils.dropout((PtNDArray) input, rate, training));
     }
 
+    /** {@inheritDoc} */
+    @Override
+    public NDList layerNorm(
+            NDArray input, Shape normalizedShape, NDArray gamma, NDArray beta, float eps) {
+        return new NDList(
+                JniUtils.layerNorm(
+                        (PtNDArray) input,
+                        normalizedShape,
+                        (PtNDArray) gamma,
+                        (PtNDArray) beta,
+                        eps));
+    }
     /** {@inheritDoc} */
     @Override
     public NDList batchNorm(
diff --git a/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/JniUtils.java b/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/JniUtils.java
index a111b89aa0c..db613bae3c8 100644
--- a/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/JniUtils.java
+++ b/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/JniUtils.java
@@ -1094,6 +1094,18 @@ public static PtNDArray batchNorm(
                         eps));
     }
 
+    public static PtNDArray layerNorm(
+            PtNDArray ndArray, Shape normalizedShape, PtNDArray gamma, PtNDArray beta, double eps) {
+        return new PtNDArray(
+                ndArray.getManager(),
+                PyTorchLibrary.LIB.torchNNLayerNorm(
+                        ndArray.getHandle(),
+                        normalizedShape.getShape(),
+                        gamma.getHandle(),
+                        beta.getHandle(),
+                        eps));
+    }
+
     public static PtNDArray dropout(PtNDArray ndArray, double prob, boolean training) {
         return new PtNDArray(
                 ndArray.getManager(),
diff --git a/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/PyTorchLibrary.java b/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/PyTorchLibrary.java
index 0bb7dbafccf..c078e7c7cef 100644
--- a/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/PyTorchLibrary.java
+++ b/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/PyTorchLibrary.java
@@ -386,6 +386,13 @@ native long torchNNConvNd(
 
     native long torchNNDropout(long inputHandle, double probability, boolean isTrain);
 
+    native long torchNNLayerNorm(
+            long inputHandle,
+            long[] normalizedShape,
+            long weightHandle,
+            long biasHandle,
+            double eps);
+
     native long torchNNBatchNorm(
             long inputHandle,
             long runningMeanHandle,
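// Editor's note: an illustrative sketch, not part of the patch, of the functional entry point
// added in LayerNorm.java. On the PyTorch engine the call is routed through
// PtNDArrayEx.layerNorm -> JniUtils.layerNorm -> torchNNLayerNorm (implemented in the native
// file below); the MXNet and TensorFlow engines in this change throw
// UnsupportedOperationException. The values and the assumption that PyTorch is the default
// engine on the classpath are for the example only.
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.Shape;
import ai.djl.nn.norm.LayerNorm;

public final class FunctionalLayerNormSketch {
    public static void main(String[] args) {
        try (NDManager manager = NDManager.newBaseManager()) {
            NDArray input = manager.create(new float[] {1, 3, 2, 4}, new Shape(2, 2));
            Shape normalizedShape = new Shape(2); // normalize over the last dimension
            NDArray gamma = manager.ones(normalizedShape); // identity scale
            NDArray beta = manager.zeros(normalizedShape); // no shift
            NDArray output =
                    LayerNorm.layerNorm(input, normalizedShape, gamma, beta, 1e-5f)
                            .singletonOrThrow();
            System.out.println(output); // approximately [[-1, 1], [-1, 1]]
        }
    }
}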
diff --git a/pytorch/pytorch-native/src/main/native/ai_djl_pytorch_jni_PyTorchLibrary_nn_functional.cc b/pytorch/pytorch-native/src/main/native/ai_djl_pytorch_jni_PyTorchLibrary_nn_functional.cc
index 59afabcecf5..84a7dfd6334 100644
--- a/pytorch/pytorch-native/src/main/native/ai_djl_pytorch_jni_PyTorchLibrary_nn_functional.cc
+++ b/pytorch/pytorch-native/src/main/native/ai_djl_pytorch_jni_PyTorchLibrary_nn_functional.cc
@@ -144,6 +144,25 @@ JNIEXPORT jlong JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_torchNNBatchNorm(
   API_END_RETURN()
 }
 
+JNIEXPORT jlong JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_torchNNLayerNorm(
+    JNIEnv* env, jobject jthis, jlong jinput, jlongArray jnormalizedshape, jlong jweight, jlong jbias, jdouble jeps) {
+  API_BEGIN()
+  const auto* tensor_ptr = reinterpret_cast<torch::Tensor*>(jinput);
+  const auto normalized_shape_vec = djl::utils::jni::GetVecFromJLongArray(env, jnormalizedshape);
+  torch::Tensor weight = {};
+  torch::Tensor bias = {};
+  if (jweight != djl::utils::jni::NULL_PTR) {
+    weight = *reinterpret_cast<torch::Tensor*>(jweight);
+  }
+  if (jbias != djl::utils::jni::NULL_PTR) {
+    bias = *reinterpret_cast<torch::Tensor*>(jbias);
+  }
+  const auto* result_ptr = new torch::Tensor(torch::nn::functional::layer_norm(*tensor_ptr,
+      torch::nn::functional::LayerNormFuncOptions(normalized_shape_vec).weight(weight).bias(bias).eps(jeps)));
+  return reinterpret_cast<uintptr_t>(result_ptr);
+  API_END_RETURN()
+}
+
 JNIEXPORT jlong JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_torchNNDropout(
     JNIEnv* env, jobject jthis, jlong jinput, jdouble probability, jboolean jtraining) {
   API_BEGIN()
diff --git a/tensorflow/tensorflow-engine/src/main/java/ai/djl/tensorflow/engine/TfNDArrayEx.java b/tensorflow/tensorflow-engine/src/main/java/ai/djl/tensorflow/engine/TfNDArrayEx.java
index b5de2b3942d..579a85ac63f 100644
--- a/tensorflow/tensorflow-engine/src/main/java/ai/djl/tensorflow/engine/TfNDArrayEx.java
+++ b/tensorflow/tensorflow-engine/src/main/java/ai/djl/tensorflow/engine/TfNDArrayEx.java
@@ -378,6 +378,12 @@ public NDList dropout(NDArray input, float rate, boolean training) {
         throw new UnsupportedOperationException("Not implemented");
     }
 
+    /** {@inheritDoc} */
+    @Override
+    public NDList layerNorm(
+            NDArray input, Shape normalizedShape, NDArray gamma, NDArray beta, float eps) {
+        throw new UnsupportedOperationException();
+    }
     /** {@inheritDoc} */
     @Override
     public NDList batchNorm(