Implementation of ResMLP and gMLP (with improvements to MLPMixer and PatchEmbedding) (#125)
Merged

Commits (13)
All 13 commits are by theabhirath:

- `f87c1cb` Initial commit for ResMLP
- `ad5c12e` Added GC calls between testsets
- `2b6ceee` More GC + reduce batch size
- `b5ca55c` Even more GC
- `1296dc4` Added gMLP
- `c8333a6` Minor Fixes for LayerScale
- `7d5cdac` Fixes I
- `8d7a41b` mlp_block API update
- `a3181a4` Fixes II
- `3617b74` GC as much as humanly possible
- `6eb93b6` Weight init for gating unit
- `0189aa9` Cleanup
- `f2250d0` Apply suggestions from code review
```diff
@@ -1,16 +1,44 @@
 """
-    mlp_block(planes, hidden_planes; dropout = 0., dense = Dense, activation = gelu)
+    mlp_block(inplanes::Integer, hidden_planes::Integer, outplanes::Integer = inplanes;
+              dropout = 0., activation = gelu)
 
-Feedforward block used in many vision transformer-like models.
+Feedforward block used in many MLPMixer-like and vision-transformer models.
 
 # Arguments
-- `planes`: Number of dimensions in the input and output.
+- `inplanes`: Number of dimensions in the input.
 - `hidden_planes`: Number of dimensions in the intermediate layer.
+- `outplanes`: Number of dimensions in the output - by default it is the same as `inplanes`.
 - `dropout`: Dropout rate.
-- `dense`: Type of dense layer to use in the feedforward block.
 - `activation`: Activation function to use.
 """
-function mlp_block(planes, hidden_planes; dropout = 0., dense = Dense, activation = gelu)
-    Chain(dense(planes, hidden_planes, activation), Dropout(dropout),
-          dense(hidden_planes, planes), Dropout(dropout))
+function mlp_block(inplanes::Integer, hidden_planes::Integer, outplanes::Integer = inplanes;
+                   dropout = 0., activation = gelu)
+    Chain(Dense(inplanes, hidden_planes, activation), Dropout(dropout),
+          Dense(hidden_planes, outplanes), Dropout(dropout))
 end
+
+"""
+    gated_mlp_block(gate_layer, inplanes::Integer, hidden_planes::Integer,
+                    outplanes::Integer = inplanes; dropout = 0., activation = gelu)
+
+Feedforward block based on the implementation in the paper "Pay Attention to MLPs".
+([reference](https://arxiv.org/abs/2105.08050))
+
+# Arguments
+- `gate_layer`: Layer to use for the gating.
+- `inplanes`: Number of dimensions in the input.
+- `hidden_planes`: Number of dimensions in the intermediate layer.
+- `outplanes`: Number of dimensions in the output - by default it is the same as `inplanes`.
+- `dropout`: Dropout rate.
+- `activation`: Activation function to use.
+"""
+function gated_mlp_block(gate_layer, inplanes::Integer, hidden_planes::Integer,
+                         outplanes::Integer = inplanes; dropout = 0., activation = gelu)
+    @assert hidden_planes % 2 == 0 "`hidden_planes` must be even for gated MLP"
+    return Chain(Dense(inplanes, hidden_planes, activation),
+                 Dropout(dropout),
+                 gate_layer(hidden_planes),
+                 Dense(hidden_planes ÷ 2, outplanes),
+                 Dropout(dropout))
+end
+gated_mlp_block(::typeof(identity), args...; kwargs...) = mlp_block(args...; kwargs...)
```
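The `gated_mlp_block` above halves the channel dimension inside the gate, which is why it asserts that `hidden_planes` is even and sizes the final `Dense` at `hidden_planes ÷ 2`. A minimal plain-Julia sketch of that split-and-gate step (`simple_gate` is a hypothetical stand-in for gMLP's spatial gating unit, which additionally applies LayerNorm and a learnable spatial projection to one half before multiplying):

```julia
# Hypothetical stand-in for a gate layer: take `hidden_planes` channels,
# split them into two halves u and v, and return `hidden_planes ÷ 2` channels.
# x has shape (channels, tokens), the channels-first layout Dense operates on.
function simple_gate(x::AbstractMatrix)
    half = size(x, 1) ÷ 2
    u = x[1:half, :]
    v = x[(half + 1):end, :]
    return u .* v   # elementwise gating; gMLP projects v spatially first
end

x = ones(Float32, 8, 4)   # 8 hidden channels, 4 tokens
y = simple_gate(x)
size(y)                   # channel dimension halved: (4, 4)
```

This halving is also why `gated_mlp_block` falls back to the plain `mlp_block` when the gate is `identity`: with no gate, the hidden width is never reduced.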
New file (38 added lines):

```julia
"""
    LayerScale(scale)

Implements LayerScale.
([reference](https://arxiv.org/abs/2103.17239))

# Arguments
- `scale`: Scaling factor, a learnable diagonal matrix which is multiplied to the input.
"""
struct LayerScale{T<:AbstractVector{<:Real}}
    scale::T
end

"""
    LayerScale(planes::Int, λ)

Implements LayerScale.
([reference](https://arxiv.org/abs/2103.17239))

# Arguments
- `planes`: Size of channel dimension in the input.
- `λ`: Initialisation value for the learnable diagonal matrix.
"""
LayerScale(planes::Int, λ) = λ > 0 ? LayerScale(fill(Float32(λ), planes)) : identity

@functor LayerScale
(m::LayerScale)(x::AbstractArray) = m.scale .* x

"""
    DropPath(p)

Implements Stochastic Depth - equivalent to `Dropout(p; dims = 4)` when `p` ≥ 0.
([reference](https://arxiv.org/abs/1603.09382))

# Arguments
- `p`: rate of Stochastic Depth.
"""
DropPath(p) = p ≥ 0 ? Dropout(p; dims = 4) : identity
```
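LayerScale's forward pass is just a learnable per-channel factor broadcast against the input. A minimal plain-Julia sketch of that broadcast (no Flux; `ToyLayerScale` is a hypothetical stand-in, and the channels-first layout is assumed purely for illustration):

```julia
# Hypothetical stand-in for the LayerScale layer above (plain Julia, no Flux).
struct ToyLayerScale{T<:AbstractVector{<:Real}}
    scale::T   # one learnable factor per channel
end
(m::ToyLayerScale)(x::AbstractArray) = m.scale .* x

ls = ToyLayerScale(fill(0.1f0, 3))   # 3 channels, each initialised to λ = 0.1
x = ones(Float32, 3, 5)              # channels × tokens
y = ls(x)                            # each channel scaled by its own factor
```

`DropPath`, by contrast, adds no new machinery at all: with `dims = 4` the dropout mask varies only along the fourth (batch) dimension, so during training entire samples are zeroed and rescaled, which is the per-sample form of Stochastic Depth.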