4 changes: 2 additions & 2 deletions src/Metalhead.jl
@@ -43,13 +43,13 @@ export AlexNet,
DenseNet, DenseNet121, DenseNet161, DenseNet169, DenseNet201,
ResNeXt,
MobileNetv2, MobileNetv3,
MLPMixer,
MLPMixer, ResMLP, gMLP,
ViT,
ConvNeXt

# use Flux._big_show to pretty print large models
for T in (:AlexNet, :VGG, :ResNet, :GoogLeNet, :Inception3, :SqueezeNet, :DenseNet, :ResNeXt,
:MobileNetv2, :MobileNetv3, :MLPMixer, :ViT, :ConvNeXt)
:MobileNetv2, :MobileNetv3, :MLPMixer, :ResMLP, :gMLP, :ViT, :ConvNeXt)
@eval Base.show(io::IO, ::MIME"text/plain", model::$T) = _maybe_big_show(io, model)
end
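
For illustration (not part of the diff), this is the method the `@eval` loop above generates for one of the newly exported types, assuming `_maybe_big_show` is the package-internal helper that wraps `Flux._big_show`:

```julia
Base.show(io::IO, ::MIME"text/plain", model::ResMLP) = _maybe_big_show(io, model)
```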

14 changes: 7 additions & 7 deletions src/convnets/convnext.jl
@@ -1,25 +1,22 @@
"""
convnextblock(planes, drop_path = 0., λ = 1f-6)
convnextblock(planes, drop_path_rate = 0., λ = 1f-6)

Creates a single block of ConvNeXt.
([reference](https://arxiv.org/abs/2201.03545))

# Arguments:
- `planes`: number of input channels.
- `drop_path_rate`: Stochastic depth rate.
- `λ`: Init value for [LayerScale](https://arxiv.org/abs/2103.17239)
- `λ`: Init value for LayerScale
"""
function convnextblock(planes, drop_path_rate = 0., λ = 1f-6)
γ = Flux.ones32(planes) * λ
LayerScale(x) = x .* γ
scale = λ > 0 ? identity : LayerScale
layers = SkipConnection(Chain(DepthwiseConv((7, 7), planes => planes; pad = 3),
x -> permutedims(x, (3, 1, 2, 4)),
LayerNorm(planes; ϵ = 1f-6),
mlp_block(planes, 4 * planes),
scale, # LayerScale
LayerScale(planes, λ),
x -> permutedims(x, (2, 3, 1, 4)),
Dropout(drop_path_rate, dims = 4)), +)
DropPath(drop_path_rate)), +)
return layers
end
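
A hypothetical usage sketch for the block above (not part of the diff; assumes the unexported builder is reachable as `Metalhead.convnextblock`):

```julia
using Flux, Metalhead

# A ConvNeXt block is residual, so the output shape matches the input shape.
block = Metalhead.convnextblock(96, 0.1, 1f-6)   # planes, drop_path_rate, λ
x = rand(Float32, 56, 56, 96, 1)                 # (width, height, channels, batch)
@assert size(block(x)) == size(x)
```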

@@ -89,9 +86,12 @@ Creates a ConvNeXt model.
- `drop_path_rate`: Stochastic depth rate.
- `λ`: Init value for [LayerScale](https://arxiv.org/abs/2103.17239)
- `nclasses`: number of output classes

See also [`Metalhead.convnext`](#).
"""
function ConvNeXt(mode::Symbol = :base; inchannels = 3, drop_path_rate = 0., λ = 1f-6,
nclasses = 1000)
@assert mode in keys(convnext_configs) "`mode` must be one of $(collect(keys(convnext_configs)))"
depths = convnext_configs[mode][:depths]
planes = convnext_configs[mode][:planes]
layers = convnext(depths, planes; inchannels, drop_path_rate, λ, nclasses)
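
A hypothetical usage sketch for the constructor above (not part of the diff); it relies only on the documented default mode `:base` being a valid key of `convnext_configs`:

```julia
using Metalhead

model = ConvNeXt(:base; drop_path_rate = 0.1, nclasses = 10)
```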
2 changes: 1 addition & 1 deletion src/convnets/densenet.jl
@@ -80,7 +80,7 @@ function densenet(inplanes, growth_rates; reduction = 0.5, nclasses = 1000)
end

"""
densenet(nblocks; growth_rate = 32, reduction = 0.5, num_classes = 1000)
densenet(nblocks; growth_rate = 32, reduction = 0.5, nclasses = 1000)

Create a DenseNet model
([reference](https://arxiv.org/abs/1608.06993)).
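
A hypothetical usage sketch for the builder documented above (not part of the diff); `[6, 12, 24, 16]` is the standard DenseNet-121 block configuration, and passing it as a vector is an assumption about what `nblocks` accepts:

```julia
using Metalhead

model = Metalhead.densenet([6, 12, 24, 16]; growth_rate = 32, reduction = 0.5, nclasses = 1000)
```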
8 changes: 3 additions & 5 deletions src/convnets/mobilenet.jl
@@ -131,11 +131,9 @@ function mobilenetv3(width_mult, configs; max_width = 1024, nclasses = 1000)
# building last several layers
output_channel = max_width
output_channel = width_mult > 1.0 ? _round_channels(output_channel * width_mult, 8) : output_channel
classifier = (
Dense(explanes, output_channel, hardswish),
Dropout(0.2),
Dense(output_channel, nclasses),
)
classifier = (Dense(explanes, output_channel, hardswish),
Dropout(0.2),
Dense(output_channel, nclasses))

return Chain(Chain(layers...,
conv_bn((1, 1), inplanes, explanes, hardswish, bias = false)...),
4 changes: 3 additions & 1 deletion src/layers/Layers.jl
@@ -13,10 +13,12 @@ include("embeddings.jl")
include("mlp.jl")
include("normalise.jl")
include("conv.jl")
include("others.jl")

export Attention, MHAttention,
PatchEmbedding, ViPosEmbedding, ClassTokens,
mlp_block,
mlp_block, gated_mlp_block,
LayerScale, DropPath,
ChannelLayerNorm, prenorm,
skip_identity, skip_projection,
conv_bn,
41 changes: 24 additions & 17 deletions src/layers/embeddings.jl
@@ -1,27 +1,34 @@
"""
PatchEmbedding(patch_size)
PatchEmbedding(patch_height, patch_width)
PatchEmbedding(imsize::NTuple{2} = (224, 224); inchannels = 3, patch_size = (16, 16),
embedplanes = 768, norm_layer = planes -> identity, flatten = true)

Patch embedding layer used by many vision transformer-like models to split the input image into patches.
"""
struct PatchEmbedding
patch_height::Int
patch_width::Int
end
Patch embedding layer used by many vision transformer-like models to split the input image into
patches.

PatchEmbedding(patch_size) = PatchEmbedding(patch_size, patch_size)
# Arguments:
- `imsize`: the size of the input image
- `inchannels`: the number of channels in the input image
- `patch_size`: the size of the patches
- `embedplanes`: the number of channels in the embedding
- `norm_layer`: the normalization layer applied after the embedding - by default the identity function,
 otherwise a single-argument constructor for a normalization layer like LayerNorm or BatchNorm
- `flatten`: set to true to flatten the spatial dimensions of the output after the embedding
"""
function PatchEmbedding(imsize::NTuple{2} = (224, 224); inchannels = 3, patch_size = (16, 16),
embedplanes = 768, norm_layer = planes -> identity, flatten = true)

function (p::PatchEmbedding)(x)
h, w, c, n = size(x)
hp, wp = h ÷ p.patch_height, w ÷ p.patch_width
xpatch = reshape(x, hp, p.patch_height, wp, p.patch_width, c, n)
im_height, im_width = imsize
patch_height, patch_width = patch_size

return reshape(permutedims(xpatch, (1, 3, 5, 2, 4, 6)), p.patch_height * p.patch_width * c,
hp * wp, n)
@assert (im_height % patch_height == 0) && (im_width % patch_width == 0) "Image dimensions must be divisible by the patch size."

return Chain(Conv(patch_size, inchannels => embedplanes; stride = patch_size),
flatten ? x -> permutedims(reshape(x, (:, size(x, 3), size(x, 4))), (2, 1, 3))
: identity,
norm_layer(embedplanes))
end

@functor PatchEmbedding
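
A hypothetical usage sketch for the new functional `PatchEmbedding` (not part of the diff; assumes the `Layers` export is visible as `Metalhead.PatchEmbedding`):

```julia
using Flux, Metalhead

# A 224×224 RGB image split into 16×16 patches gives 14 × 14 = 196 tokens.
embed = Metalhead.PatchEmbedding((224, 224); inchannels = 3, patch_size = (16, 16),
                                 embedplanes = 768)
x = rand(Float32, 224, 224, 3, 1)
size(embed(x))   # (768, 196, 1): (embedplanes, npatches, batch)
```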

"""
ViPosEmbedding(embedsize, npatches; init = (dims) -> rand(Float32, dims))

42 changes: 35 additions & 7 deletions src/layers/mlp.jl
@@ -1,16 +1,44 @@
"""
mlp_block(planes, hidden_planes; dropout = 0., dense = Dense, activation = gelu)
mlp_block(inplanes::Integer, hidden_planes::Integer, outplanes::Integer = inplanes;
dropout = 0., activation = gelu)

Feedforward block used in many vision transformer-like models.
Feedforward block used in many MLPMixer-like and vision-transformer models.

# Arguments
- `planes`: Number of dimensions in the input and output.
- `inplanes`: Number of dimensions in the input.
- `hidden_planes`: Number of dimensions in the intermediate layer.
- `outplanes`: Number of dimensions in the output - by default it is the same as `inplanes`.
- `dropout`: Dropout rate.
- `dense`: Type of dense layer to use in the feedforward block.
- `activation`: Activation function to use.
"""
function mlp_block(planes, hidden_planes; dropout = 0., dense = Dense, activation = gelu)
Chain(dense(planes, hidden_planes, activation), Dropout(dropout),
dense(hidden_planes, planes), Dropout(dropout))
function mlp_block(inplanes::Integer, hidden_planes::Integer, outplanes::Integer = inplanes;
dropout = 0., activation = gelu)
Chain(Dense(inplanes, hidden_planes, activation), Dropout(dropout),
Dense(hidden_planes, outplanes), Dropout(dropout))
end
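
A hypothetical usage sketch for `mlp_block` (not part of the diff; assumes the `Layers` export is visible as `Metalhead.mlp_block`):

```julia
using Flux, Metalhead

ff = Metalhead.mlp_block(512, 2048; dropout = 0.1)   # inplanes, hidden_planes
x = rand(Float32, 512, 196, 4)                       # (features, tokens, batch)
size(ff(x))                                          # (512, 196, 4)
```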

"""
gated_mlp_block(gate_layer, inplanes::Integer, hidden_planes::Integer,
outplanes::Integer = inplanes; dropout = 0., activation = gelu)

Feedforward block based on the implementation in the paper "Pay Attention to MLPs".
([reference](https://arxiv.org/abs/2105.08050))

# Arguments
- `gate_layer`: Layer to use for the gating.
- `inplanes`: Number of dimensions in the input.
- `hidden_planes`: Number of dimensions in the intermediate layer.
- `outplanes`: Number of dimensions in the output - by default it is the same as `inplanes`.
- `dropout`: Dropout rate.
- `activation`: Activation function to use.
"""
function gated_mlp_block(gate_layer, inplanes::Integer, hidden_planes::Integer,
outplanes::Integer = inplanes; dropout = 0., activation = gelu)
@assert hidden_planes % 2 == 0 "`hidden_planes` must be even for gated MLP"
return Chain(Dense(inplanes, hidden_planes, activation),
Dropout(dropout),
gate_layer(hidden_planes),
Dense(hidden_planes ÷ 2, outplanes),
Dropout(dropout))
end
gated_mlp_block(::typeof(identity), args...; kwargs...) = mlp_block(args...; kwargs...)
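
A hypothetical usage sketch for `gated_mlp_block` (not part of the diff). The `toygate` below is a stand-in for a real gating layer such as gMLP's spatial gating unit; the only contract it must honour is mapping `hidden_planes` features to `hidden_planes ÷ 2` along the first dimension:

```julia
using Flux, Metalhead

# Toy gate: ignores `planes` and multiplies the two halves of the hidden features.
toygate(planes) = x -> x[1:(end ÷ 2), :, :] .* x[(end ÷ 2 + 1):end, :, :]

gmlp = Metalhead.gated_mlp_block(toygate, 256, 512)
x = rand(Float32, 256, 196, 2)   # (features, tokens, batch)
size(gmlp(x))                    # (256, 196, 2)
```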
38 changes: 38 additions & 0 deletions src/layers/others.jl
@@ -0,0 +1,38 @@
"""
LayerScale(scale)

Implements LayerScale.
([reference](https://arxiv.org/abs/2103.17239))

# Arguments
- `scale`: Scaling factor - a learnable vector of per-channel scales (the diagonal of a scaling matrix) that is multiplied elementwise with the input.
"""
struct LayerScale{T<:AbstractVector{<:Real}}
scale::T
end

"""
LayerScale(planes::Int, λ)

Implements LayerScale.
([reference](https://arxiv.org/abs/2103.17239))

# Arguments
- `planes`: Size of channel dimension in the input.
- `λ`: initialisation value for the learnable diagonal matrix.
"""
LayerScale(planes::Int, λ) = λ > 0 ? LayerScale(fill(Float32(λ), planes)) : identity

@functor LayerScale
(m::LayerScale)(x::AbstractArray) = m.scale .* x
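
A hypothetical usage sketch for `LayerScale` (not part of the diff; assumes it is reachable as `Metalhead.LayerScale`):

```julia
using Flux, Metalhead

ls = Metalhead.LayerScale(96, 1f-6)   # 96 channels, each scale initialised to 1f-6
x = rand(Float32, 96, 196, 1)         # channels along the first dimension
ls(x) ≈ 1f-6 .* x                     # true at initialisation; the scales are trainable
```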

"""
DropPath(p)

Implements Stochastic Depth - equivalent to `Dropout(p; dims = 4)` when `p` ≥ 0.
([reference](https://arxiv.org/abs/1603.09382))

# Arguments
- `p`: rate of Stochastic Depth.
"""
DropPath(p) = p ≥ 0 ? Dropout(p; dims = 4) : identity
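
A hypothetical usage sketch for `DropPath` (not part of the diff; assumes it is reachable as `Metalhead.DropPath`). Because `dims = 4` ties the dropout mask to the batch dimension, each sample's residual branch is either kept (and rescaled) or zeroed as a whole:

```julia
using Flux, Metalhead

dp = Metalhead.DropPath(0.2)    # equivalent to Dropout(0.2; dims = 4)
x = ones(Float32, 7, 7, 96, 8)
Flux.trainmode!(dp)             # dropout layers are only active in training mode
y = dp(x)                       # each sample is all zeros or scaled by 1 / (1 - 0.2)
```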