
Commit 13cbf02

Merge pull request #125 from theabhirath/resmlp
Implementation of ResMLP and gMLP (with improvements to MLPMixer and `PatchEmbedding`)
2 parents 178a0ff + f2250d0 commit 13cbf02

14 files changed (+428 / -125 lines)

src/Metalhead.jl

Lines changed: 2 additions & 2 deletions
@@ -43,13 +43,13 @@ export AlexNet,
        DenseNet, DenseNet121, DenseNet161, DenseNet169, DenseNet201,
        ResNeXt,
        MobileNetv2, MobileNetv3,
-       MLPMixer,
+       MLPMixer, ResMLP, gMLP,
        ViT,
        ConvNeXt

 # use Flux._big_show to pretty print large models
 for T in (:AlexNet, :VGG, :ResNet, :GoogLeNet, :Inception3, :SqueezeNet, :DenseNet, :ResNeXt,
-          :MobileNetv2, :MobileNetv3, :MLPMixer, :ViT, :ConvNeXt)
+          :MobileNetv2, :MobileNetv3, :MLPMixer, :ResMLP, :gMLP, :ViT, :ConvNeXt)
     @eval Base.show(io::IO, ::MIME"text/plain", model::$T) = _maybe_big_show(io, model)
 end

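With the new exports in place, `ResMLP` and `gMLP` can be constructed like the other model wrappers and get the same pretty-printing. A minimal sketch (assuming the default constructors these models receive elsewhere in this PR; the exact keyword arguments are not shown in this excerpt):

    using Metalhead

    # Construct the newly exported models with their default configurations.
    resmlp = ResMLP()
    gmlp = gMLP()

    # Both types were added to the `_maybe_big_show` loop above, so they are
    # printed layer-by-layer in the REPL like the other Metalhead models.
    resmlp
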
src/convnets/convnext.jl

Lines changed: 7 additions & 7 deletions
@@ -1,25 +1,22 @@
 """
-    convnextblock(planes, drop_path = 0., λ = 1f-6)
+    convnextblock(planes, drop_path_rate = 0., λ = 1f-6)

 Creates a single block of ConvNeXt.
 ([reference](https://arxiv.org/abs/2201.03545))

 # Arguments:
 - `planes`: number of input channels.
 - `drop_path_rate`: Stochastic depth rate.
-- `λ`: Init value for [LayerScale](https://arxiv.org/abs/2103.17239)
+- `λ`: Init value for LayerScale
 """
 function convnextblock(planes, drop_path_rate = 0., λ = 1f-6)
-    γ = Flux.ones32(planes) * λ
-    LayerScale(x) = x .* γ
-    scale = λ > 0 ? identity : LayerScale
     layers = SkipConnection(Chain(DepthwiseConv((7, 7), planes => planes; pad = 3),
                                   x -> permutedims(x, (3, 1, 2, 4)),
                                   LayerNorm(planes; ϵ = 1f-6),
                                   mlp_block(planes, 4 * planes),
-                                  scale, # LayerScale
+                                  LayerScale(planes, λ),
                                   x -> permutedims(x, (2, 3, 1, 4)),
-                                  Dropout(drop_path_rate, dims = 4)), +)
+                                  DropPath(drop_path_rate)), +)
     return layers
 end

@@ -89,9 +86,12 @@ Creates a ConvNeXt model.
 - `drop_path_rate`: Stochastic depth rate.
 - `λ`: Init value for [LayerScale](https://arxiv.org/abs/2103.17239)
 - `nclasses`: number of output classes
+
+See also [`Metalhead.convnext`](#).
 """
 function ConvNeXt(mode::Symbol = :base; inchannels = 3, drop_path_rate = 0., λ = 1f-6,
                   nclasses = 1000)
+    @assert mode in keys(convnext_configs) "`mode` must be one of $(collect(keys(convnext_configs)))"
     depths = convnext_configs[mode][:depths]
     planes = convnext_configs[mode][:planes]
     layers = convnext(depths, planes; inchannels, drop_path_rate, λ, nclasses)

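For reference, a short sketch of the effect of the new `mode` check in `ConvNeXt` (illustrative only; it assumes `convnext_configs` keeps keys such as `:tiny` and `:base` as defined elsewhere in this file):

    using Metalhead

    # A known configuration symbol passes the new assertion.
    model = ConvNeXt(:tiny)

    # An unknown symbol now fails fast with an AssertionError listing the
    # valid configurations, instead of a KeyError from `convnext_configs`.
    # ConvNeXt(:gigantic)
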
src/convnets/densenet.jl

Lines changed: 1 addition & 1 deletion
@@ -80,7 +80,7 @@ function densenet(inplanes, growth_rates; reduction = 0.5, nclasses = 1000)
 end

 """
-    densenet(nblocks; growth_rate = 32, reduction = 0.5, num_classes = 1000)
+    densenet(nblocks; growth_rate = 32, reduction = 0.5, nclasses = 1000)

 Create a DenseNet model
 ([reference](https://arxiv.org/abs/1608.06993)).

src/convnets/mobilenet.jl

Lines changed: 3 additions & 5 deletions
@@ -131,11 +131,9 @@ function mobilenetv3(width_mult, configs; max_width = 1024, nclasses = 1000)
     # building last several layers
     output_channel = max_width
     output_channel = width_mult > 1.0 ? _round_channels(output_channel * width_mult, 8) : output_channel
-    classifier = (
-        Dense(explanes, output_channel, hardswish),
-        Dropout(0.2),
-        Dense(output_channel, nclasses),
-    )
+    classifier = (Dense(explanes, output_channel, hardswish),
+                  Dropout(0.2),
+                  Dense(output_channel, nclasses))

     return Chain(Chain(layers...,
                        conv_bn((1, 1), inplanes, explanes, hardswish, bias = false)...),

src/layers/Layers.jl

Lines changed: 3 additions & 1 deletion
@@ -13,10 +13,12 @@ include("embeddings.jl")
 include("mlp.jl")
 include("normalise.jl")
 include("conv.jl")
+include("others.jl")

 export Attention, MHAttention,
        PatchEmbedding, ViPosEmbedding, ClassTokens,
-       mlp_block,
+       mlp_block, gated_mlp_block,
+       LayerScale, DropPath,
        ChannelLayerNorm, prenorm,
        skip_identity, skip_projection,
        conv_bn,

src/layers/embeddings.jl

Lines changed: 24 additions & 17 deletions
@@ -1,27 +1,34 @@
 """
-    PatchEmbedding(patch_size)
-    PatchEmbedding(patch_height, patch_width)
+    PatchEmbedding(imsize::NTuple{2} = (224, 224); inchannels = 3, patch_size = (16, 16),
+                   embedplanes = 768, norm_layer = planes -> identity, flatten = true)

-Patch embedding layer used by many vision transformer-like models to split the input image into patches.
-"""
-struct PatchEmbedding
-    patch_height::Int
-    patch_width::Int
-end
+Patch embedding layer used by many vision transformer-like models to split the input image into
+patches.

-PatchEmbedding(patch_size) = PatchEmbedding(patch_size, patch_size)
+# Arguments:
+- `imsize`: the size of the input image
+- `inchannels`: the number of channels in the input image
+- `patch_size`: the size of the patches
+- `embedplanes`: the number of channels in the embedding
+- `norm_layer`: the normalization layer - by default the identity function but otherwise takes a
+  single argument constructor for a normalization layer like LayerNorm or BatchNorm
+- `flatten`: set true to flatten the input spatial dimensions after the embedding
+"""
+function PatchEmbedding(imsize::NTuple{2} = (224, 224); inchannels = 3, patch_size = (16, 16),
+                        embedplanes = 768, norm_layer = planes -> identity, flatten = true)

-function (p::PatchEmbedding)(x)
-    h, w, c, n = size(x)
-    hp, wp = h ÷ p.patch_height, w ÷ p.patch_width
-    xpatch = reshape(x, hp, p.patch_height, wp, p.patch_width, c, n)
+    im_height, im_width = imsize
+    patch_height, patch_width = patch_size

-    return reshape(permutedims(xpatch, (1, 3, 5, 2, 4, 6)), p.patch_height * p.patch_width * c,
-                   hp * wp, n)
+    @assert (im_height % patch_height == 0) && (im_width % patch_width == 0)
+    "Image dimensions must be divisible by the patch size."
+
+    return Chain(Conv(patch_size, inchannels => embedplanes; stride = patch_size),
+                 flatten ? x -> permutedims(reshape(x, (:, size(x, 3), size(x, 4))), (2, 1, 3))
+                         : identity,
+                 norm_layer(embedplanes))
 end

-@functor PatchEmbedding
-
 """
     ViPosEmbedding(embedsize, npatches; init = (dims) -> rand(Float32, dims))


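The rewritten `PatchEmbedding` is now an ordinary `Chain` built around a strided convolution, so it can be exercised directly. A rough usage sketch (assuming the layer is reachable as `Metalhead.Layers.PatchEmbedding`, per the export list above):

    using Flux
    using Metalhead.Layers: PatchEmbedding

    x = rand(Float32, 224, 224, 3, 1)        # H × W × C × N image batch
    embed = PatchEmbedding((224, 224); patch_size = (16, 16), embedplanes = 768)

    # With `flatten = true` (the default) the 14 × 14 grid of patches is
    # flattened, giving an embedplanes × npatches × batch array.
    size(embed(x))                           # (768, 196, 1)
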
src/layers/mlp.jl

Lines changed: 35 additions & 7 deletions
@@ -1,16 +1,44 @@
 """
-    mlp_block(planes, hidden_planes; dropout = 0., dense = Dense, activation = gelu)
+    mlp_block(inplanes::Integer, hidden_planes::Integer, outplanes::Integer = inplanes;
+              dropout = 0., activation = gelu)

-Feedforward block used in many vision transformer-like models.
+Feedforward block used in many MLPMixer-like and vision-transformer models.

 # Arguments
-- `planes`: Number of dimensions in the input and output.
+- `inplanes`: Number of dimensions in the input.
 - `hidden_planes`: Number of dimensions in the intermediate layer.
+- `outplanes`: Number of dimensions in the output - by default it is the same as `inplanes`.
 - `dropout`: Dropout rate.
-- `dense`: Type of dense layer to use in the feedforward block.
 - `activation`: Activation function to use.
 """
-function mlp_block(planes, hidden_planes; dropout = 0., dense = Dense, activation = gelu)
-    Chain(dense(planes, hidden_planes, activation), Dropout(dropout),
-          dense(hidden_planes, planes), Dropout(dropout))
+function mlp_block(inplanes::Integer, hidden_planes::Integer, outplanes::Integer = inplanes;
+                   dropout = 0., activation = gelu)
+    Chain(Dense(inplanes, hidden_planes, activation), Dropout(dropout),
+          Dense(hidden_planes, outplanes), Dropout(dropout))
 end
+
+"""
+    gated_mlp_block(gate_layer, inplanes::Integer, hidden_planes::Integer,
+                    outplanes::Integer = inplanes; dropout = 0., activation = gelu)
+
+Feedforward block based on the implementation in the paper "Pay Attention to MLPs".
+([reference](https://arxiv.org/abs/2105.08050))
+
+# Arguments
+- `gate_layer`: Layer to use for the gating.
+- `inplanes`: Number of dimensions in the input.
+- `hidden_planes`: Number of dimensions in the intermediate layer.
+- `outplanes`: Number of dimensions in the output - by default it is the same as `inplanes`.
+- `dropout`: Dropout rate.
+- `activation`: Activation function to use.
+"""
+function gated_mlp_block(gate_layer, inplanes::Integer, hidden_planes::Integer,
+                         outplanes::Integer = inplanes; dropout = 0., activation = gelu)
+    @assert hidden_planes % 2 == 0 "`hidden_planes` must be even for gated MLP"
+    return Chain(Dense(inplanes, hidden_planes, activation),
+                 Dropout(dropout),
+                 gate_layer(hidden_planes),
+                 Dense(hidden_planes ÷ 2, outplanes),
+                 Dropout(dropout))
+end
+gated_mlp_block(::typeof(identity), args...; kwargs...) = mlp_block(args...; kwargs...)

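A quick shape check for `mlp_block` and `gated_mlp_block` (a sketch only; `toy_gate` is a hypothetical stand-in for the spatial gating layer that the gMLP model actually passes as `gate_layer`, and the `Metalhead.Layers` import path is assumed from the export list above):

    using Flux
    using Metalhead.Layers: mlp_block, gated_mlp_block

    x = rand(Float32, 64, 196, 2)            # features × patches × batch

    # Plain feedforward block: 64 -> 256 -> 64.
    mlp = mlp_block(64, 256)
    size(mlp(x))                             # (64, 196, 2)

    # The gate must map hidden_planes -> hidden_planes ÷ 2, matching the
    # Dense(hidden_planes ÷ 2, outplanes) that follows it in the block.
    toy_gate = planes -> Dense(planes, planes ÷ 2)
    gmlp = gated_mlp_block(toy_gate, 64, 256)
    size(gmlp(x))                            # (64, 196, 2)
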
src/layers/others.jl

Lines changed: 38 additions & 0 deletions
@@ -0,0 +1,38 @@
+"""
+    LayerScale(scale)
+
+Implements LayerScale.
+([reference](https://arxiv.org/abs/2103.17239))
+
+# Arguments
+- `scale`: Scaling factor, a learnable diagonal matrix which is multiplied to the input.
+"""
+struct LayerScale{T<:AbstractVector{<:Real}}
+    scale::T
+end
+
+"""
+    LayerScale(planes::Int, λ)
+
+Implements LayerScale.
+([reference](https://arxiv.org/abs/2103.17239))
+
+# Arguments
+- `planes`: Size of channel dimension in the input.
+- `λ`: initialisation value for the learnable diagonal matrix.
+"""
+LayerScale(planes::Int, λ) = λ > 0 ? LayerScale(fill(Float32(λ), planes)) : identity
+
+@functor LayerScale
+(m::LayerScale)(x::AbstractArray) = m.scale .* x
+
+"""
+    DropPath(p)
+
+Implements Stochastic Depth - equivalent to `Dropout(p; dims = 4)` when `p` ≥ 0.
+([reference](https://arxiv.org/abs/1603.09382))
+
+# Arguments
+- `p`: rate of Stochastic Depth.
+"""
+DropPath(p) = p ≥ 0 ? Dropout(p; dims = 4) : identity

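A minimal sketch of the two new layers (shapes are illustrative; note that `LayerScale` broadcasts its vector along the first dimension, which is why `convnextblock` moves channels to the front before applying it, and the `Metalhead.Layers` import path is assumed from the export list above):

    using Flux
    using Metalhead.Layers: LayerScale, DropPath

    # LayerScale: learnable per-channel factors initialised to λ.
    # With λ = 0 the constructor returns `identity` and the layer is a no-op.
    ls = LayerScale(32, 1f-6)
    xc = rand(Float32, 32, 49, 4)            # channels × tokens × batch
    size(ls(xc))                             # (32, 49, 4)

    # DropPath: Dropout along dims = 4, so whole samples of a W × H × C × N
    # batch are zeroed at random during training (stochastic depth).
    dp = DropPath(0.1)
    xn = rand(Float32, 7, 7, 32, 4)
    size(dp(xn))                             # (7, 7, 32, 4)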