LayerNorm using PyTorch #1069
Merged
Changes shown are from 1 commit (of the 8 commits in this pull request).

Commits:
a630f71  LayerNorm using PyTorch (enpasos)
438f703  Update mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxNDArray… (enpasos)
3faffb3  Update pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtN… (enpasos)
bd1c60c  Update tensorflow/tensorflow-engine/src/main/java/ai/djl/tensorflow/e… (enpasos)
380fb3b  Version 1 (enpasos)
b780542  save/loadMetadata fixes (enpasos)
ae54fc1  "formatCpp" (enpasos)
1b1c97d  removed Line (enpasos)
New file: LayerNorm.java (package ai.djl.nn.norm)

@@ -0,0 +1,256 @@
/*
 * Copyright 2021 Amazon.com, Inc. or its affiliates. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
 * with the License. A copy of the License is located at
 *
 * http://aws.amazon.com/apache2.0/
 *
 * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
 * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
 * and limitations under the License.
 */
package ai.djl.nn.norm;

import ai.djl.Device;
import ai.djl.MalformedModelException;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.internal.NDArrayEx;
import ai.djl.ndarray.types.Shape;
import ai.djl.nn.AbstractBlock;
import ai.djl.nn.Parameter;
import ai.djl.training.ParameterStore;
import ai.djl.util.PairList;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.util.Arrays;

/**
 * Layer normalization works by normalizing the values of input data for each input sample to have
 * a mean of 0 and a variance of 1. Since this may alter the representation of a layer, two
 * parameters (\(\gamma\) and \(\beta\)) are learned along with the normalization process to
 * respectively scale and shift the normalized output (activations) to have any mean and variance,
 * so the network can utilize non-linear transformations such as the sigmoid function, as described
 * in the <a href="https://arxiv.org/abs/1607.06450">paper</a>. During backpropagation, both
 * \(\gamma\) and \(\beta\) parameters are included following the chain rule in derivation.
 *
 * <p>Citing the abstract of the paper: "Training state-of-the-art, deep neural networks is
 * computationally expensive. One way to reduce the training time is to normalize the activities of
 * the neurons. A recently introduced technique called batch normalization uses the distribution of
 * the summed input to a neuron over a mini-batch of training cases to compute a mean and variance
 * which are then used to normalize the summed input to that neuron on each training case. This
 * significantly reduces the training time in feed-forward neural networks. However, the effect of
 * batch normalization is dependent on the mini-batch size and it is not obvious how to apply it to
 * recurrent neural networks. In this paper, we transpose batch normalization into layer
 * normalization by computing the mean and variance used for normalization from all of the summed
 * inputs to the neurons in a layer on a single training case. Like batch normalization, we also
 * give each neuron its own adaptive bias and gain which are applied after the normalization but
 * before the non-linearity. Unlike batch normalization, layer normalization performs exactly the
 * same computation at training and test times. It is also straightforward to apply to recurrent
 * neural networks by computing the normalization statistics separately at each time step. Layer
 * normalization is very effective at stabilizing the hidden state dynamics in recurrent networks.
 * Empirically, we show that layer normalization can substantially reduce the training time
 * compared with previously published techniques."
 */
public class LayerNorm extends AbstractBlock {

    private static final byte VERSION = 2;

    private float epsilon;
    private Shape normalizedShape;

    private boolean center;
    private boolean scale;
    private int[] axis;
    private Parameter gamma;
    private Parameter beta;

    LayerNorm(Builder builder) {
        super(VERSION);
        epsilon = builder.epsilon;
        scale = builder.scale;
        center = builder.center;
        axis = builder.axis;

        // make gamma trainable if scale
        gamma =
                addParameter(
                        Parameter.builder()
                                .setName("gamma")
                                .setType(Parameter.Type.GAMMA)
                                .optRequiresGrad(scale)
                                .build());
        // make beta trainable if center
        beta =
                addParameter(
                        Parameter.builder()
                                .setName("beta")
                                .setType(Parameter.Type.BETA)
                                .optRequiresGrad(center)
                                .build());
    }

    /**
     * Applies Layer Normalization with average and variance for each input sample across the axis
     * dimensions.
     *
     * @param input the input {@code NDArray} of shape (batchSize, inputChannel, *), * could be
     *     empty, width, (height, width), (depth, height, width)
     * @param normalizedShape dimensions to calculate average and variance from
     * @param gamma gamma weight {@code NDArray}
     * @param beta beta weight {@code NDArray}
     * @param eps a value added to the denominator for numerical stability
     * @return the output {@code NDArray} of shape (batchSize, inputChannel, *), * could be empty,
     *     width, (height, width), (depth, height, width)
     */
    public static NDList layerNorm(
            NDArray input, Shape normalizedShape, NDArray gamma, NDArray beta, float eps) {
        NDArrayEx ex = input.getNDArrayInternal();
        return ex.layerNorm(input, normalizedShape, gamma, beta, eps);
    }

    /**
     * Creates a builder to build a {@code LayerNorm}.
     *
     * @return a new builder
     */
    public static Builder builder() {
        return new Builder();
    }

    /** {@inheritDoc} */
    @Override
    protected NDList forwardInternal(
            ParameterStore parameterStore,
            NDList inputs,
            boolean training,
            PairList<String, Object> params) {
        NDArray input = inputs.singletonOrThrow();
        Device device = input.getDevice();
        NDArray gammaArr = parameterStore.getValue(gamma, device, training);
        NDArray betaArr = parameterStore.getValue(beta, device, training);

        return layerNorm(input, normalizedShape, gammaArr, betaArr, epsilon);
    }

    /** {@inheritDoc} */
    @Override
    public Shape[] getOutputShapes(Shape[] inputShapes) {
        return new Shape[] {inputShapes[0]};
    }

    /** {@inheritDoc} */
    @Override
    protected void beforeInitialize(Shape... inputShapes) {
        super.beforeInitialize(inputShapes);
        normalizedShape =
                axis == null
                        ? inputShapes[0].slice(1)
                        : new Shape(
                                Arrays.stream(axis)
                                        .mapToLong(dim -> inputShapes[0].get(dim))
                                        .toArray());
    }

    /** {@inheritDoc} */
    @Override
    public void prepare(Shape[] inputShapes) {
        gamma.setShape(normalizedShape);
        beta.setShape(normalizedShape);
    }

    /** {@inheritDoc} */
    @Override
    protected void saveMetadata(DataOutputStream os) throws IOException {
        saveInputShapes(os);
        os.writeInt(normalizedShape.getShape().length);
        for (int i = 0; i < normalizedShape.getShape().length; i++) {
            os.writeLong(normalizedShape.getShape()[i]);
        }
    }

    /** {@inheritDoc} */
    @Override
    public void loadMetadata(byte version, DataInputStream is)
            throws IOException, MalformedModelException {
        if (version == VERSION) {
            readInputShapes(is);
        } else if (version != 1) {
            throw new MalformedModelException("Unsupported encoding version: " + version);
        }
        long[] shapeRaw = new long[is.readInt()];
        for (int i = 0; i < shapeRaw.length; i++) {
            shapeRaw[i] = is.readLong();
        }
        normalizedShape = new Shape(shapeRaw);
    }
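    // Reviewer note (added for clarity, not part of this PR's diff): with VERSION == 2 the
    // metadata layout is the block's input shapes (saveInputShapes), followed by the length of
    // normalizedShape and each of its dimensions as a long. Version-1 metadata has no leading
    // input shapes, which is why loadMetadata calls readInputShapes only when version == 2.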

    /** The Builder to construct a {@link LayerNorm}. */
    public static final class Builder {

        private float epsilon = 1E-5f;
        // private Shape normalizedShape;
        private boolean scale = true;
        private boolean center = true;
        private int[] axis;

        Builder() {}

        /**
         * Lists the axes over which the mean and variance will be calculated (alternative to
         * normalizedShape).
         *
         * @param axis input axes over which the mean and variance will be calculated (if null,
         *     all dimensions except the first, batch, dimension)
         * @return this Builder
         */
        public Builder axis(int... axis) {
            this.axis = axis;
            return this;
        }

        /**
         * If True, add offset of `beta` to the normalized tensor. Defaults to True.
         *
         * @param val True or False on whether to add and train the offset value
         * @return this Builder
         */
        public Builder optCenter(boolean val) {
            center = val;
            return this;
        }

        /**
         * If True, multiply the result by `gamma`. Defaults to True.
         *
         * @param val True or False on whether to add and train the scale value
         * @return this Builder
         */
        public Builder optScale(boolean val) {
            scale = val;
            return this;
        }

        /**
         * Sets the epsilon value to prevent division by 0.
         *
         * @param val the epsilon value
         * @return this Builder
         */
        public Builder optEpsilon(float val) {
            epsilon = val;
            return this;
        }

        /**
         * Builds a {@link LayerNorm} block.
         *
         * @return the {@link LayerNorm} block
         */
        public LayerNorm build() {
            return new LayerNorm(this);
        }
    }
}
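For reviewers who want to try the change out, here is a minimal usage sketch of the API added in this PR. It is illustrative only: the class name, shapes, and epsilon are made up for the example, and it assumes an engine that implements layerNorm (this PR wires the native path for the PyTorch engine) is on the classpath.

import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.Shape;
import ai.djl.nn.norm.LayerNorm;

public class LayerNormUsageExample {

    public static void main(String[] args) {
        try (NDManager manager = NDManager.newBaseManager()) {
            // a batch of 2 samples with 4 features each (shapes are illustrative)
            NDArray input = manager.randomNormal(new Shape(2, 4));
            // normalize over the feature dimension; gamma/beta have that shape
            Shape normalizedShape = new Shape(4);
            NDArray gamma = manager.ones(normalizedShape);
            NDArray beta = manager.zeros(normalizedShape);

            // static helper added by this PR; dispatches through NDArrayEx to the engine
            NDList out = LayerNorm.layerNorm(input, normalizedShape, gamma, beta, 1e-5f);
            System.out.println(out.singletonOrThrow());

            // block form, configured through the new Builder (axis 1 = feature dimension)
            LayerNorm block = LayerNorm.builder().axis(1).optEpsilon(1e-5f).build();
        }
    }
}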
PyTorch engine (pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtN…):

@@ -411,6 +411,18 @@ public NDList dropout(NDArray input, float rate, boolean training) {
        return new NDList(JniUtils.dropout((PtNDArray) input, rate, training));
    }

    @Override
    public NDList layerNorm(
            NDArray input, Shape normalizedShape, NDArray gamma, NDArray beta, float eps) {

        return new NDList(
                JniUtils.layerNorm(
                        (PtNDArray) input,
                        normalizedShape,
                        (PtNDArray) gamma,
                        (PtNDArray) beta,
                        eps));
    }
Contributor: Add an empty line here.

enpasos (Author): I have tried to fix it ... but could be that I am blind here ... I am using
||
| /** {@inheritDoc} */ | ||
| @Override | ||
| public NDList batchNorm( | ||
|
|
||
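The engine override above simply forwards to JniUtils.layerNorm, i.e. to PyTorch's native kernel. As a reviewer aid, the following rough sketch shows the computation that call is expected to perform, expressed with plain NDArray ops under the assumption of a (batch, features) input normalized over the last axis. It is not the code path added by this PR, and the class and method names are hypothetical.

import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.Shape;

public final class LayerNormReferenceSketch {

    /** Normalizes over the last axis, then applies gamma and beta (broadcast element-wise). */
    static NDArray layerNormReference(NDArray x, NDArray gamma, NDArray beta, float eps) {
        int[] lastAxis = {x.getShape().dimension() - 1};
        NDArray mean = x.mean(lastAxis, true);                    // per-sample mean
        NDArray centered = x.sub(mean);
        NDArray variance = centered.mul(centered).mean(lastAxis, true);
        NDArray normalized = centered.div(variance.add(eps).sqrt());
        return normalized.mul(gamma).add(beta);                   // scale and shift
    }

    public static void main(String[] args) {
        try (NDManager manager = NDManager.newBaseManager()) {
            NDArray x = manager.randomNormal(new Shape(2, 4));
            NDArray gamma = manager.ones(new Shape(4));
            NDArray beta = manager.zeros(new Shape(4));
            System.out.println(layerNormReference(x, gamma, beta, 1e-5f));
        }
    }
}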