diff --git a/api/src/main/java/ai/djl/ndarray/internal/NDArrayEx.java b/api/src/main/java/ai/djl/ndarray/internal/NDArrayEx.java
index 4388d68b765..4a405308b1f 100644
--- a/api/src/main/java/ai/djl/ndarray/internal/NDArrayEx.java
+++ b/api/src/main/java/ai/djl/ndarray/internal/NDArrayEx.java
@@ -296,6 +296,8 @@ NDList deconvolution(
NDList dropout(NDArray input, float rate, boolean training);
+ NDList layerNorm(NDArray input, Shape normalizedShape, NDArray gamma, NDArray beta, float eps);
+
NDList batchNorm(
NDArray input,
NDArray runningMean,
diff --git a/api/src/main/java/ai/djl/nn/norm/LayerNorm.java b/api/src/main/java/ai/djl/nn/norm/LayerNorm.java
new file mode 100644
index 00000000000..0861e1fa72f
--- /dev/null
+++ b/api/src/main/java/ai/djl/nn/norm/LayerNorm.java
@@ -0,0 +1,248 @@
+/*
+ * Copyright 2021 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
+ * with the License. A copy of the License is located at
+ *
+ * http://aws.amazon.com/apache2.0/
+ *
+ * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
+ * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
+ * and limitations under the License.
+ */
+package ai.djl.nn.norm;
+
+import ai.djl.Device;
+import ai.djl.MalformedModelException;
+import ai.djl.ndarray.NDArray;
+import ai.djl.ndarray.NDList;
+import ai.djl.ndarray.internal.NDArrayEx;
+import ai.djl.ndarray.types.Shape;
+import ai.djl.nn.AbstractBlock;
+import ai.djl.nn.Parameter;
+import ai.djl.training.ParameterStore;
+import ai.djl.util.PairList;
+import java.io.DataInputStream;
+import java.io.DataOutputStream;
+import java.io.IOException;
+import java.util.Arrays;
+
+/**
+ * Layer normalization works by normalizing the values of input data for each input sample to have
+ * a mean of 0 and a variance of 1. Since this may alter the representation of a layer, two
+ * parameters (\(\gamma\) and \(\beta\)) are learned along the normalization process to respectively
+ * scale and shift the normalized output (activations) so it can have any mean and variance, letting
+ * the network still utilize non-linear transformations such as the sigmoid function, as described
+ * in the paper. During backpropagation, both the \(\gamma\) and \(\beta\) parameters are included
+ * following the chain rule in derivation.
+ *
+ * <p>Citing the abstract of the paper: "Training state-of-the-art, deep neural networks is
+ * computationally expensive. One way to reduce the training time is to normalize the activities of
+ * the neurons. A recently introduced technique called batch normalization uses the distribution of
+ * the summed input to a neuron over a mini-batch of training cases to compute a mean and variance
+ * which are then used to normalize the summed input to that neuron on each training case. This
+ * significantly reduces the training time in feed-forward neural networks. However, the effect of
+ * batch normalization is dependent on the mini-batch size and it is not obvious how to apply it to
+ * recurrent neural networks. In this paper, we transpose batch normalization into layer
+ * normalization by computing the mean and variance used for normalization from all of the summed
+ * inputs to the neurons in a layer on a single training case. Like batch normalization, we also
+ * give each neuron its own adaptive bias and gain which are applied after the normalization but
+ * before the non-linearity. Unlike batch normalization, layer normalization performs exactly the
+ * same computation at training and test times. It is also straightforward to apply to recurrent
+ * neural networks by computing the normalization statistics separately at each time step. Layer
+ * normalization is very effective at stabilizing the hidden state dynamics in recurrent networks.
+ * Empirically, we show that layer normalization can substantially reduce the training time compared
+ * with previously published techniques."
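+ *
+ * <p>A minimal usage sketch (illustrative only; the surrounding model and trainer setup is
+ * assumed and not shown):
+ *
+ * <pre>{@code
+ * // normalize each sample over all non-batch dimensions (the default)
+ * Block norm = LayerNorm.builder().build();
+ * // or normalize only over selected input axes
+ * Block channelNorm = LayerNorm.builder().axis(2, 3).optEpsilon(1e-5f).build();
+ * }</pre>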
+ */
+public class LayerNorm extends AbstractBlock {
+
+ private static final byte VERSION = 1;
+
+ private float epsilon;
+ private Shape normalizedShape;
+
+ private boolean center;
+ private boolean scale;
+ private int[] axis;
+ private Parameter gamma;
+ private Parameter beta;
+
+ LayerNorm(Builder builder) {
+ super(VERSION);
+ epsilon = builder.epsilon;
+ scale = builder.scale;
+ center = builder.center;
+ axis = builder.axis;
+
+ // make gamma trainable if scale
+ gamma =
+ addParameter(
+ Parameter.builder()
+ .setName("gamma")
+ .setType(Parameter.Type.GAMMA)
+ .optRequiresGrad(scale)
+ .build());
+ // make beta trainable if center
+ beta =
+ addParameter(
+ Parameter.builder()
+ .setName("beta")
+ .setType(Parameter.Type.BETA)
+ .optRequiresGrad(center)
+ .build());
+ }
+
+ /**
+ * Applies Layer Normalization to each input sample, using the mean and variance computed over
+ * the dimensions given by {@code normalizedShape}.
+ *
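+ * <p>For example (a sketch; {@code manager} is assumed to be an existing {@code NDManager}):
+ *
+ * <pre>{@code
+ * NDArray x = manager.create(new float[] {1, 3, 2, 4}, new Shape(2, 2));
+ * NDArray gamma = manager.ones(new Shape(2));
+ * NDArray beta = manager.zeros(new Shape(2));
+ * NDArray y = LayerNorm.layerNorm(x, new Shape(2), gamma, beta, 1e-5f).singletonOrThrow();
+ * // each row of y is approximately {-1, 1}: zero mean and unit variance per sample
+ * }</pre>
+ *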
+ * @param input the input {@code NDArray} of shape (batchSize, inputChannel, *), where * could
+ * be empty, width, (height, width), or (depth, height, width)
+ * @param normalizedShape dimensions to calculate average and variance from
+ * @param gamma gamma weight {@code NDArray}
+ * @param beta beta weight {@code NDArray}
+ * @param eps a value added to the denominator for numerical stability
+ * @return an {@code NDList} whose single element is the output {@code NDArray} of shape
+ * (batchSize, inputChannel, *), where * could be empty, width, (height, width), or (depth,
+ * height, width)
+ */
+ public static NDList layerNorm(
+ NDArray input, Shape normalizedShape, NDArray gamma, NDArray beta, float eps) {
+ NDArrayEx ex = input.getNDArrayInternal();
+ return ex.layerNorm(input, normalizedShape, gamma, beta, eps);
+ }
+
+ /**
+ * Creates a builder to build a {@code LayerNorm}.
+ *
+ * @return a new builder
+ */
+ public static Builder builder() {
+ return new Builder();
+ }
+
+ /** {@inheritDoc} */
+ @Override
+ protected NDList forwardInternal(
+ ParameterStore parameterStore,
+ NDList inputs,
+ boolean training,
+ PairList<String, Object> params) {
+ NDArray input = inputs.singletonOrThrow();
+ Device device = input.getDevice();
+ NDArray gammaArr = parameterStore.getValue(gamma, device, training);
+ NDArray betaArr = parameterStore.getValue(beta, device, training);
+
+ return layerNorm(input, normalizedShape, gammaArr, betaArr, epsilon);
+ }
+
+ /** {@inheritDoc} */
+ @Override
+ public Shape[] getOutputShapes(Shape[] inputShapes) {
+ return new Shape[] {inputShapes[0]};
+ }
+
+ /** {@inheritDoc} */
+ @Override
+ protected void beforeInitialize(Shape... inputShapes) {
+ super.beforeInitialize(inputShapes);
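+ // with no explicit axis, normalize over every dimension except the first (batch) dimension;
+ // otherwise the normalized shape is built from the sizes of the requested axes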
+ normalizedShape =
+ axis == null
+ ? inputShapes[0].slice(1)
+ : new Shape(
+ Arrays.stream(axis)
+ .mapToLong(dim -> inputShapes[0].get(dim))
+ .toArray());
+ }
+
+ /** {@inheritDoc} */
+ @Override
+ public void prepare(Shape[] inputShapes) {
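+ // gamma and beta are learned per element of the normalized shape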
+ gamma.setShape(normalizedShape);
+ beta.setShape(normalizedShape);
+ }
+
+ /** {@inheritDoc} */
+ @Override
+ protected void saveMetadata(DataOutputStream os) throws IOException {
+ saveInputShapes(os);
+ os.write(normalizedShape.getEncoded());
+ }
+
+ /** {@inheritDoc} */
+ @Override
+ public void loadMetadata(byte version, DataInputStream is)
+ throws IOException, MalformedModelException {
+ if (version != VERSION) {
+ throw new MalformedModelException("Unsupported encoding version: " + version);
+ }
+ readInputShapes(is);
+ normalizedShape = Shape.decode(is);
+ }
+
+ /** The Builder to construct a {@link LayerNorm}. */
+ public static final class Builder {
+
+ private float epsilon = 1E-5f;
+ private boolean scale = true;
+ private boolean center = true;
+ private int[] axis;
+
+ Builder() {}
+
+ /**
+ * Lists the axes over which the mean and variance will be calculated (as an alternative to
+ * specifying normalizedShape directly).
+ *
+ * @param axis the input axes over which the mean and variance will be calculated; if null,
+ * every dimension except the first (batch) dimension is used
+ * @return this Builder
+ */
+ public Builder axis(int... axis) {
+ this.axis = axis;
+ return this;
+ }
+
+ /**
+ * If true, adds the offset of {@code beta} to the normalized tensor. Defaults to true.
+ *
+ * @param val whether to add and train the offset value {@code beta}
+ * @return this Builder
+ */
+ public Builder optCenter(boolean val) {
+ center = val;
+ return this;
+ }
+
+ /**
+ * If true, multiplies the result by {@code gamma}. Defaults to true.
+ *
+ * @param val whether to add and train the scale value {@code gamma}
+ * @return this Builder
+ */
+ public Builder optScale(boolean val) {
+ scale = val;
+ return this;
+ }
+
+ /**
+ * Sets the epsilon value to prevent division by 0.
+ *
+ * @param val the epsilon value
+ * @return this Builder
+ */
+ public Builder optEpsilon(float val) {
+ epsilon = val;
+ return this;
+ }
+
+ /**
+ * Builds a {@link LayerNorm} block.
+ *
+ * @return the {@link LayerNorm} block
+ */
+ public LayerNorm build() {
+ return new LayerNorm(this);
+ }
+ }
+}
diff --git a/integration/src/main/java/ai/djl/integration/tests/nn/BlockCoreTest.java b/integration/src/main/java/ai/djl/integration/tests/nn/BlockCoreTest.java
index c133b382327..e577f2a3d8a 100644
--- a/integration/src/main/java/ai/djl/integration/tests/nn/BlockCoreTest.java
+++ b/integration/src/main/java/ai/djl/integration/tests/nn/BlockCoreTest.java
@@ -12,6 +12,7 @@
*/
package ai.djl.integration.tests.nn;
+import ai.djl.Device;
import ai.djl.MalformedModelException;
import ai.djl.Model;
import ai.djl.engine.Engine;
@@ -36,6 +37,7 @@
import ai.djl.nn.core.Linear;
import ai.djl.nn.norm.BatchNorm;
import ai.djl.nn.norm.Dropout;
+import ai.djl.nn.norm.LayerNorm;
import ai.djl.nn.recurrent.GRU;
import ai.djl.nn.recurrent.LSTM;
import ai.djl.nn.recurrent.RNN;
@@ -202,6 +204,60 @@ public void testBatchNorm() throws IOException, MalformedModelException {
}
}
+ @SuppressWarnings("try")
+ @Test
+ public void testLayerNorm() throws IOException, MalformedModelException {
+ TrainingConfig config =
+ new DefaultTrainingConfig(Loss.l2Loss())
+ .optInitializer(Initializer.ONES, Parameter.Type.WEIGHT);
+
+ Block block = LayerNorm.builder().build();
+ try (Model model = Model.newInstance("model", Device.cpu(), "PyTorch")) {
+ model.setBlock(block);
+
+ try (Trainer trainer = model.newTrainer(config)) {
+ try (GradientCollector collector = trainer.newGradientCollector()) {
+ Shape inputShape = new Shape(2, 2);
+ trainer.initialize(inputShape);
+
+ NDManager manager = trainer.getManager();
+ NDArray data = manager.create(new float[] {1, 3, 2, 4}, inputShape);
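+ // each row is normalized independently: {1, 3} and {2, 4} both have unit variance,
+ // so the normalized values are approximately -1 and 1 (exact up to epsilon)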
+ NDArray expected = manager.create(new float[] {-1, 1, -1, 1}, inputShape);
+ NDArray result = trainer.forward(new NDList(data)).singletonOrThrow();
+ Assertions.assertAlmostEquals(result, expected);
+ testEncode(manager, block);
+ }
+ }
+ }
+ }
+
+ @SuppressWarnings("try")
+ @Test
+ public void testLayerNormWithAxis() throws IOException, MalformedModelException {
+ TrainingConfig config =
+ new DefaultTrainingConfig(Loss.l2Loss())
+ .optInitializer(Initializer.ONES, Parameter.Type.WEIGHT);
+
+ Block block = LayerNorm.builder().axis(2, 3).build();
+ try (Model model = Model.newInstance("model", Device.cpu(), "PyTorch")) {
+ model.setBlock(block);
+
+ try (Trainer trainer = model.newTrainer(config)) {
+ try (GradientCollector collector = trainer.newGradientCollector()) {
+ Shape inputShape = new Shape(1, 2, 1, 2);
+ trainer.initialize(inputShape);
+
+ NDManager manager = trainer.getManager();
+ NDArray data = manager.create(new float[] {1, 3, 2, 4}, inputShape);
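+ // with axis(2, 3) each (height, width) slice is normalized on its own: {1, 3} and
+ // {2, 4} again map to approximately {-1, 1}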
+ NDArray expected = manager.create(new float[] {-1, 1, -1, 1}, inputShape);
+ NDArray result = trainer.forward(new NDList(data)).singletonOrThrow();
+ Assertions.assertAlmostEquals(result, expected);
+ testEncode(manager, block);
+ }
+ }
+ }
+ }
+
@SuppressWarnings("try")
@Test
public void testDropout() throws IOException, MalformedModelException {
diff --git a/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxNDArrayEx.java b/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxNDArrayEx.java
index 7437ba0d2ef..31feabb58e2 100644
--- a/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxNDArrayEx.java
+++ b/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxNDArrayEx.java
@@ -636,6 +636,13 @@ public NDList dropout(NDArray input, float rate, boolean training) {
return getManager().invoke("_npx_dropout", new NDList(input), params);
}
+ /** {@inheritDoc} */
+ @Override
+ public NDList layerNorm(
+ NDArray input, Shape normalizedShape, NDArray gamma, NDArray beta, float eps) {
+ throw new UnsupportedOperationException("Not implemented");
+ }
+
/** {@inheritDoc} */
@Override
public NDList batchNorm(
diff --git a/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArrayEx.java b/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArrayEx.java
index 6c2dd0488ae..bce4a877240 100644
--- a/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArrayEx.java
+++ b/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArrayEx.java
@@ -411,6 +411,18 @@ public NDList dropout(NDArray input, float rate, boolean training) {
return new NDList(JniUtils.dropout((PtNDArray) input, rate, training));
}
+ /** {@inheritDoc} */
+ @Override
+ public NDList layerNorm(
+ NDArray input, Shape normalizedShape, NDArray gamma, NDArray beta, float eps) {
+ return new NDList(
+ JniUtils.layerNorm(
+ (PtNDArray) input,
+ normalizedShape,
+ (PtNDArray) gamma,
+ (PtNDArray) beta,
+ eps));
+ }
+
+ /** {@inheritDoc} */
@Override
public NDList batchNorm(
diff --git a/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/JniUtils.java b/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/JniUtils.java
index a111b89aa0c..db613bae3c8 100644
--- a/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/JniUtils.java
+++ b/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/JniUtils.java
@@ -1094,6 +1094,18 @@ public static PtNDArray batchNorm(
eps));
}
+ public static PtNDArray layerNorm(
+ PtNDArray ndArray, Shape normalizedShape, PtNDArray gamma, PtNDArray beta, double eps) {
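+ // the returned native handle is wrapped in a new PtNDArray owned by the input's manager;
+ // eps is received as double to match the native torchNNLayerNorm signature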
+ return new PtNDArray(
+ ndArray.getManager(),
+ PyTorchLibrary.LIB.torchNNLayerNorm(
+ ndArray.getHandle(),
+ normalizedShape.getShape(),
+ gamma.getHandle(),
+ beta.getHandle(),
+ eps));
+ }
+
public static PtNDArray dropout(PtNDArray ndArray, double prob, boolean training) {
return new PtNDArray(
ndArray.getManager(),
diff --git a/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/PyTorchLibrary.java b/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/PyTorchLibrary.java
index 0bb7dbafccf..c078e7c7cef 100644
--- a/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/PyTorchLibrary.java
+++ b/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/PyTorchLibrary.java
@@ -386,6 +386,13 @@ native long torchNNConvNd(
native long torchNNDropout(long inputHandle, double probability, boolean isTrain);
+ native long torchNNLayerNorm(
+ long inputHandle,
+ long[] normalizedShape,
+ long weightHandle,
+ long biasHandle,
+ double eps);
+
native long torchNNBatchNorm(
long inputHandle,
long runningMeanHandle,
diff --git a/pytorch/pytorch-native/src/main/native/ai_djl_pytorch_jni_PyTorchLibrary_nn_functional.cc b/pytorch/pytorch-native/src/main/native/ai_djl_pytorch_jni_PyTorchLibrary_nn_functional.cc
index 59afabcecf5..84a7dfd6334 100644
--- a/pytorch/pytorch-native/src/main/native/ai_djl_pytorch_jni_PyTorchLibrary_nn_functional.cc
+++ b/pytorch/pytorch-native/src/main/native/ai_djl_pytorch_jni_PyTorchLibrary_nn_functional.cc
@@ -144,6 +144,25 @@ JNIEXPORT jlong JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_torchNNBatchNorm(
API_END_RETURN()
}
+JNIEXPORT jlong JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_torchNNLayerNorm(
+ JNIEnv* env, jobject jthis, jlong jinput, jlongArray jnormalizedshape, jlong jweight, jlong jbias, jdouble jeps) {
+ API_BEGIN()
+ const auto* tensor_ptr = reinterpret_cast<torch::Tensor*>(jinput);
+ const auto normalized_shape_vec = djl::utils::jni::GetVecFromJLongArray(env, jnormalizedshape);
+ torch::Tensor weight = {};
+ torch::Tensor bias = {};
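+ // a zero (NULL_PTR) handle means no gamma/beta was supplied; leaving the tensor undefined
+ // makes layer_norm skip the corresponding affine term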
+ if (jweight != djl::utils::jni::NULL_PTR) {
+ weight = *reinterpret_cast<torch::Tensor*>(jweight);
+ }
+ if (jbias != djl::utils::jni::NULL_PTR) {
+ bias = *reinterpret_cast<torch::Tensor*>(jbias);
+ }
+ const auto* result_ptr = new torch::Tensor(torch::nn::functional::layer_norm(*tensor_ptr,
+ torch::nn::functional::LayerNormFuncOptions(normalized_shape_vec).weight(weight).bias(bias).eps(jeps)));
+ return reinterpret_cast<uintptr_t>(result_ptr);
+ API_END_RETURN()
+}
+
JNIEXPORT jlong JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_torchNNDropout(
JNIEnv* env, jobject jthis, jlong jinput, jdouble probability, jboolean jtraining) {
API_BEGIN()
diff --git a/tensorflow/tensorflow-engine/src/main/java/ai/djl/tensorflow/engine/TfNDArrayEx.java b/tensorflow/tensorflow-engine/src/main/java/ai/djl/tensorflow/engine/TfNDArrayEx.java
index b5de2b3942d..579a85ac63f 100644
--- a/tensorflow/tensorflow-engine/src/main/java/ai/djl/tensorflow/engine/TfNDArrayEx.java
+++ b/tensorflow/tensorflow-engine/src/main/java/ai/djl/tensorflow/engine/TfNDArrayEx.java
@@ -378,6 +378,12 @@ public NDList dropout(NDArray input, float rate, boolean training) {
throw new UnsupportedOperationException("Not implemented");
}
+ /** {@inheritDoc} */
+ @Override
+ public NDList layerNorm(
+ NDArray input, Shape normalizedShape, NDArray gamma, NDArray beta, float eps) {
+ throw new UnsupportedOperationException("Not implemented");
+ }
+
/** {@inheritDoc} */
@Override
public NDList batchNorm(