diff --git a/src/Metalhead.jl b/src/Metalhead.jl index 0b278da84..316161859 100644 --- a/src/Metalhead.jl +++ b/src/Metalhead.jl @@ -43,13 +43,13 @@ export AlexNet, DenseNet, DenseNet121, DenseNet161, DenseNet169, DenseNet201, ResNeXt, MobileNetv2, MobileNetv3, - MLPMixer, + MLPMixer, ResMLP, gMLP, ViT, ConvNeXt # use Flux._big_show to pretty print large models for T in (:AlexNet, :VGG, :ResNet, :GoogLeNet, :Inception3, :SqueezeNet, :DenseNet, :ResNeXt, - :MobileNetv2, :MobileNetv3, :MLPMixer, :ViT, :ConvNeXt) + :MobileNetv2, :MobileNetv3, :MLPMixer, :ResMLP, :gMLP, :ViT, :ConvNeXt) @eval Base.show(io::IO, ::MIME"text/plain", model::$T) = _maybe_big_show(io, model) end diff --git a/src/convnets/convnext.jl b/src/convnets/convnext.jl index 2c316564a..45a7991d7 100644 --- a/src/convnets/convnext.jl +++ b/src/convnets/convnext.jl @@ -1,5 +1,5 @@ """ - convnextblock(planes, drop_path = 0., λ = 1f-6) + convnextblock(planes, drop_path_rate = 0., λ = 1f-6) Creates a single block of ConvNeXt. ([reference](https://arxiv.org/abs/2201.03545)) @@ -7,19 +7,16 @@ Creates a single block of ConvNeXt. # Arguments: - `planes`: number of input channels. - `drop_path_rate`: Stochastic depth rate. -- `λ`: Init value for [LayerScale](https://arxiv.org/abs/2103.17239) +- `λ`: Init value for LayerScale """ function convnextblock(planes, drop_path_rate = 0., λ = 1f-6) - γ = Flux.ones32(planes) * λ - LayerScale(x) = x .* γ - scale = λ > 0 ? identity : LayerScale layers = SkipConnection(Chain(DepthwiseConv((7, 7), planes => planes; pad = 3), x -> permutedims(x, (3, 1, 2, 4)), LayerNorm(planes; ϵ = 1f-6), mlp_block(planes, 4 * planes), - scale, # LayerScale + LayerScale(planes, λ), x -> permutedims(x, (2, 3, 1, 4)), - Dropout(drop_path_rate, dims = 4)), +) + DropPath(drop_path_rate)), +) return layers end @@ -89,9 +86,12 @@ Creates a ConvNeXt model. - `drop_path_rate`: Stochastic depth rate. - `λ`: Init value for [LayerScale](https://arxiv.org/abs/2103.17239) - `nclasses`: number of output classes + +See also [`Metalhead.convnext`](#). """ function ConvNeXt(mode::Symbol = :base; inchannels = 3, drop_path_rate = 0., λ = 1f-6, nclasses = 1000) + @assert mode in keys(convnext_configs) "`size` must be one of $(collect(keys(convnext_configs)))" depths = convnext_configs[mode][:depths] planes = convnext_configs[mode][:planes] layers = convnext(depths, planes; inchannels, drop_path_rate, λ, nclasses) diff --git a/src/convnets/densenet.jl b/src/convnets/densenet.jl index fa85fe548..0fc3980b5 100644 --- a/src/convnets/densenet.jl +++ b/src/convnets/densenet.jl @@ -80,7 +80,7 @@ function densenet(inplanes, growth_rates; reduction = 0.5, nclasses = 1000) end """ - densenet(nblocks; growth_rate = 32, reduction = 0.5, num_classes = 1000) + densenet(nblocks; growth_rate = 32, reduction = 0.5, nclasses = 1000) Create a DenseNet model ([reference](https://arxiv.org/abs/1608.06993)). diff --git a/src/convnets/mobilenet.jl b/src/convnets/mobilenet.jl index e208c74cb..d201e39bb 100644 --- a/src/convnets/mobilenet.jl +++ b/src/convnets/mobilenet.jl @@ -131,11 +131,9 @@ function mobilenetv3(width_mult, configs; max_width = 1024, nclasses = 1000) # building last several layers output_channel = max_width output_channel = width_mult > 1.0 ? 
_round_channels(output_channel * width_mult, 8) : output_channel - classifier = ( - Dense(explanes, output_channel, hardswish), - Dropout(0.2), - Dense(output_channel, nclasses), - ) + classifier = (Dense(explanes, output_channel, hardswish), + Dropout(0.2), + Dense(output_channel, nclasses)) return Chain(Chain(layers..., conv_bn((1, 1), inplanes, explanes, hardswish, bias = false)...), diff --git a/src/layers/Layers.jl b/src/layers/Layers.jl index 88d64c94f..4f9eff4ea 100644 --- a/src/layers/Layers.jl +++ b/src/layers/Layers.jl @@ -13,10 +13,12 @@ include("embeddings.jl") include("mlp.jl") include("normalise.jl") include("conv.jl") +include("others.jl") export Attention, MHAttention, PatchEmbedding, ViPosEmbedding, ClassTokens, - mlp_block, + mlp_block, gated_mlp_block, + LayerScale, DropPath, ChannelLayerNorm, prenorm, skip_identity, skip_projection, conv_bn, diff --git a/src/layers/embeddings.jl b/src/layers/embeddings.jl index cdfa479ec..86315b8bd 100644 --- a/src/layers/embeddings.jl +++ b/src/layers/embeddings.jl @@ -1,27 +1,34 @@ """ - PatchEmbedding(patch_size) - PatchEmbedding(patch_height, patch_width) + PatchEmbedding(imsize::NTuple{2} = (224, 224); inchannels = 3, patch_size = (16, 16), + embedplanes = 768, norm_layer = planes -> identity, flatten = true) -Patch embedding layer used by many vision transformer-like models to split the input image into patches. -""" -struct PatchEmbedding - patch_height::Int - patch_width::Int -end +Patch embedding layer used by many vision transformer-like models to split the input image into +patches. -PatchEmbedding(patch_size) = PatchEmbedding(patch_size, patch_size) +# Arguments: +- `imsize`: the size of the input image +- `inchannels`: the number of channels in the input image +- `patch_size`: the size of the patches +- `embedplanes`: the number of channels in the embedding +- `norm_layer`: the normalization layer - by default the identity function but otherwise takes a + single argument constructor for a normalization layer like LayerNorm or BatchNorm +- `flatten`: set true to flatten the input spatial dimensions after the embedding +""" +function PatchEmbedding(imsize::NTuple{2} = (224, 224); inchannels = 3, patch_size = (16, 16), + embedplanes = 768, norm_layer = planes -> identity, flatten = true) -function (p::PatchEmbedding)(x) - h, w, c, n = size(x) - hp, wp = h ÷ p.patch_height, w ÷ p.patch_width - xpatch = reshape(x, hp, p.patch_height, wp, p.patch_width, c, n) + im_height, im_width = imsize + patch_height, patch_width = patch_size - return reshape(permutedims(xpatch, (1, 3, 5, 2, 4, 6)), p.patch_height * p.patch_width * c, - hp * wp, n) + @assert (im_height % patch_height == 0) && (im_width % patch_width == 0) + "Image dimensions must be divisible by the patch size." + + return Chain(Conv(patch_size, inchannels => embedplanes; stride = patch_size), + flatten ? x -> permutedims(reshape(x, (:, size(x, 3), size(x, 4))), (2, 1, 3)) + : identity, + norm_layer(embedplanes)) end -@functor PatchEmbedding - """ ViPosEmbedding(embedsize, npatches; init = (dims) -> rand(Float32, dims)) diff --git a/src/layers/mlp.jl b/src/layers/mlp.jl index 1f5c5640b..a84d3dbe4 100644 --- a/src/layers/mlp.jl +++ b/src/layers/mlp.jl @@ -1,16 +1,44 @@ """ - mlp_block(planes, hidden_planes; dropout = 0., dense = Dense, activation = gelu) + mlp_block(inplanes::Integer, hidden_planes::Integer, outplanes::Integer = inplanes; + dropout = 0., activation = gelu) -Feedforward block used in many vision transformer-like models. 
+Feedforward block used in many MLPMixer-like and vision-transformer models. # Arguments -- `planes`: Number of dimensions in the input and output. +- `inplanes`: Number of dimensions in the input. - `hidden_planes`: Number of dimensions in the intermediate layer. +- `outplanes`: Number of dimensions in the output - by default it is the same as `inplanes`. - `dropout`: Dropout rate. -- `dense`: Type of dense layer to use in the feedforward block. - `activation`: Activation function to use. """ -function mlp_block(planes, hidden_planes; dropout = 0., dense = Dense, activation = gelu) - Chain(dense(planes, hidden_planes, activation), Dropout(dropout), - dense(hidden_planes, planes), Dropout(dropout)) +function mlp_block(inplanes::Integer, hidden_planes::Integer, outplanes::Integer = inplanes; + dropout = 0., activation = gelu) + Chain(Dense(inplanes, hidden_planes, activation), Dropout(dropout), + Dense(hidden_planes, outplanes), Dropout(dropout)) end + +""" + gated_mlp(gate_layer, inplanes::Integer, hidden_planes::Integer, + outplanes::Integer = inplanes; dropout = 0., activation = gelu) + +Feedforward block based on the implementation in the paper "Pay Attention to MLPs". +([reference](https://arxiv.org/abs/2105.08050)) + +# Arguments +- `gate_layer`: Layer to use for the gating. +- `inplanes`: Number of dimensions in the input. +- `hidden_planes`: Number of dimensions in the intermediate layer. +- `outplanes`: Number of dimensions in the output - by default it is the same as `inplanes`. +- `dropout`: Dropout rate. +- `activation`: Activation function to use. +""" +function gated_mlp_block(gate_layer, inplanes::Integer, hidden_planes::Integer, + outplanes::Integer = inplanes; dropout = 0., activation = gelu) + @assert hidden_planes % 2 == 0 "`hidden_planes` must be even for gated MLP" + return Chain(Dense(inplanes, hidden_planes, activation), + Dropout(dropout), + gate_layer(hidden_planes), + Dense(hidden_planes ÷ 2, outplanes), + Dropout(dropout)) +end +gated_mlp_block(::typeof(identity), args...; kwargs...) = mlp_block(args...; kwargs...) diff --git a/src/layers/others.jl b/src/layers/others.jl new file mode 100644 index 000000000..ebd4a7c01 --- /dev/null +++ b/src/layers/others.jl @@ -0,0 +1,38 @@ +""" + LayerScale(scale) + +Implements LayerScale. +([reference](https://arxiv.org/abs/2103.17239)) + +# Arguments +- `scale`: Scaling factor, a learnable diagonal matrix which is multiplied to the input. +""" +struct LayerScale{T<:AbstractVector{<:Real}} + scale::T +end + +""" + LayerScale(λ, planes::Int) + +Implements LayerScale. +([reference](https://arxiv.org/abs/2103.17239)) + +# Arguments +- `planes`: Size of channel dimension in the input. +- `λ`: initialisation value for the learnable diagonal matrix. +""" +LayerScale(planes::Int, λ) = λ > 0 ? LayerScale(fill(Float32(λ), planes)) : identity + +@functor LayerScale +(m::LayerScale)(x::AbstractArray) = m.scale .* x + +""" + DropPath(p) + +Implements Stochastic Depth - equivalent to `Dropout(p; dims = 4)` when `p` ≥ 0. +([reference](https://arxiv.org/abs/1603.09382)) + +# Arguments +- `p`: rate of Stochastic Depth. +""" +DropPath(p) = p ≥ 0 ? 
Dropout(p; dims = 4) : identity \ No newline at end of file diff --git a/src/other/mlpmixer.jl b/src/other/mlpmixer.jl index ecbcaa2d1..3ef02a3b8 100644 --- a/src/other/mlpmixer.jl +++ b/src/other/mlpmixer.jl @@ -1,83 +1,101 @@ -# Utility function for creating a residual block with LayerNorm before the residual connection -_residualprenorm(planes, fn) = SkipConnection(Chain(fn, LayerNorm(planes)), +) +""" + mixerblock(planes, npatches; mlp_ratio = (0.5, 4.0), mlp_layer = mlp_block, + dropout = 0., drop_path_rate = 0., activation = gelu) -# Utility function for 1D convolution -_conv1d(inplanes, outplanes, activation = identity) = Conv((1, ), inplanes => outplanes, activation) +Creates a feedforward block for the MLPMixer architecture. +([reference](https://arxiv.org/pdf/2105.01601)) +# Arguments: +- `planes`: the number of planes in the block +- `npatches`: the number of patches of the input +- `mlp_ratio`: number(s) that determine(s) the number of hidden channels in the token mixing MLP + and/or the channel mixing MLP as a ratio to the number of planes in the block. +- `mlp_layer`: the MLP layer to use in the block +- `dropout`: the dropout rate to use in the MLP blocks +- `drop_path_rate`: Stochastic depth rate +- `activation`: the activation function to use in the MLP blocks """ - mlpmixer(imsize::NTuple{2} = (256, 256); inchannels = 3, patch_size = 16, planes = 512, - depth = 12, expansion_factor = 4, dropout = 0., nclasses = 1000, token_mix = - _conv1d, channel_mix = Dense)) +function mixerblock(planes, npatches; mlp_ratio = (0.5, 4.0), mlp_layer = mlp_block, + dropout = 0., drop_path_rate = 0., activation = gelu) + tokenplanes, channelplanes = [Int(r * planes) for r in mlp_ratio] + return Chain(SkipConnection(Chain(LayerNorm(planes), + x -> permutedims(x, (2, 1, 3)), + mlp_layer(npatches, tokenplanes; activation, dropout), + x -> permutedims(x, (2, 1, 3)), + DropPath(drop_path_rate)), +), + SkipConnection(Chain(LayerNorm(planes), + mlp_layer(planes, channelplanes; activation, dropout), + DropPath(drop_path_rate)), +)) +end + +""" + mlpmixer(block, imsize::NTuple{2} = (224, 224); inchannels = 3, norm_layer = LayerNorm, + patch_size::NTuple{2} = (16, 16), embedplanes = 512, drop_path_rate = 0., + depth = 12, nclasses = 1000, kwargs...) Creates a model with the MLPMixer architecture. ([reference](https://arxiv.org/pdf/2105.01601)). 
# Arguments +- `block`: the type of mixer block to use in the model - architecture dependent + (a constructor of the form `block(embedplanes, npatches; drop_path_rate, kwargs...)`) - `imsize`: the size of the input image - `inchannels`: the number of input channels +- `norm_layer`: the normalization layer to use in the model - `patch_size`: the size of the patches -- `planes`: the number of channels fed into the main model -- `depth`: the number of blocks in the main model -- `expansion_factor`: the number of channels in each block -- `dropout`: the dropout rate -- `nclasses`: the number of classes in the output -- `token_mix`: the function to use for the token mixing layer -- `channel_mix`: the function to use for the channel mixing layer -""" -function mlpmixer(imsize::NTuple{2} = (256, 256); inchannels = 3, patch_size = 16, planes = 512, - depth = 12, expansion_factor = 4, dropout = 0., nclasses = 1000, token_mix = - _conv1d, channel_mix = Dense) - - im_height, im_width = imsize - - @assert (im_height % patch_size) == 0 && (im_width % patch_size == 0) - "image size must be divisible by patch size" - - num_patches = (im_height ÷ patch_size) * (im_width ÷ patch_size) - - layers = [] - push!(layers, PatchEmbedding(patch_size)) - push!(layers, Dense((patch_size ^ 2) * inchannels, planes)) - append!(layers, [Chain(_residualprenorm(planes, mlp_block(num_patches, - expansion_factor * num_patches; - dropout, dense = token_mix)), - _residualprenorm(planes, mlp_block(planes, - expansion_factor * planes; dropout, - dense = channel_mix)),) for _ in 1:depth]) - - classification_head = Chain(_seconddimmean, Dense(planes, nclasses)) - - return Chain(Chain(layers...), classification_head) +- `embedplanes`: the number of channels after the patch embedding (denotes the hidden dimension) +- `drop_path_rate`: Stochastic depth rate +- `depth`: the number of blocks in the model +- `nclasses`: number of output classes +- `kwargs`: additional arguments (if any) to pass to the mixer block. Will use the defaults if + not specified. +""" +function mlpmixer(block, imsize::NTuple{2} = (224, 224); inchannels = 3, norm_layer = LayerNorm, + patch_size::NTuple{2} = (16, 16), embedplanes = 512, drop_path_rate = 0., + depth = 12, nclasses = 1000, kwargs...) + npatches = prod(imsize .÷ patch_size) + dp_rates = LinRange{Float32}(0., drop_path_rate, depth) + layers = Chain(PatchEmbedding(imsize; inchannels, patch_size, embedplanes), + [block(embedplanes, npatches; drop_path_rate = dp_rates[i], kwargs...) + for i in 1:depth]...) + + classification_head = Chain(norm_layer(embedplanes), seconddimmean, Dense(embedplanes, nclasses)) + return Chain(layers, classification_head) end +# Configurations for MLPMixer models +mixer_configs = Dict(:small => Dict(:depth => 8, :planes => 512), + :base => Dict(:depth => 12, :planes => 768), + :large => Dict(:depth => 24, :planes => 1024), + :huge => Dict(:depth => 32, :planes => 1280)) + struct MLPMixer layers end """ - MLPMixer(imsize::NTuple{2} = (256, 256); inchannels = 3, patch_size = 16, planes = 512, - depth = 12, expansion_factor = 4, dropout = 0., nclasses = 1000) + MLPMixer(size::Symbol = :base; patch_size::Int = 16, imsize::NTuple{2} = (224, 224), + drop_path_rate = 0., nclasses = 1000) Creates a model with the MLPMixer architecture. ([reference](https://arxiv.org/pdf/2105.01601)). 
# Arguments -- `imsize`: the size of the input image -- `inchannels`: the number of input channels +- `size`: the size of the model - one of `small`, `base`, `large` or `huge` - `patch_size`: the size of the patches -- `planes`: the number of channels fed into the main model -- `depth`: the number of blocks in the main model -- `expansion_factor`: the number of channels in each block -- `dropout`: the dropout rate -- `nclasses`: the number of classes in the output +- `imsize`: the size of the input image +- `drop_path_rate`: Stochastic depth rate +- `nclasses`: number of output classes See also [`Metalhead.mlpmixer`](#). """ -function MLPMixer(imsize::NTuple{2} = (256, 256); inchannels = 3, patch_size = 16, planes = 512, - depth = 12, expansion_factor = 4, dropout = 0., nclasses = 1000) - - layers = mlpmixer(imsize; inchannels, patch_size, planes, depth, expansion_factor, dropout, - nclasses) +function MLPMixer(size::Symbol = :base; patch_size::Int = 16, imsize::NTuple{2} = (224, 224), + drop_path_rate = 0., nclasses = 1000) + @assert size in keys(mixer_configs) "`size` must be one of $(keys(mixer_configs))" + patch_size = _to_tuple(patch_size) + depth = mixer_configs[size][:depth] + embedplanes = mixer_configs[size][:planes] + layers = mlpmixer(mixerblock, imsize; patch_size, embedplanes, depth, drop_path_rate, nclasses) MLPMixer(layers) end @@ -87,3 +105,182 @@ end backbone(m::MLPMixer) = m.layers[1] classifier(m::MLPMixer) = m.layers[2] + +""" + resmixerblock(planes, npatches; dropout = 0., drop_path_rate = 0., mlp_ratio = 4.0, + activation = gelu, λ = 1e-4) + +Creates a block for the ResMixer architecture. +([reference](https://arxiv.org/abs/2105.03404)). + +# Arguments +- `planes`: the number of planes in the block +- `npatches`: the number of patches of the input +- `mlp_ratio`: ratio of the number of hidden channels in the channel mixing MLP to the number + of planes in the block +- `mlp_layer`: the MLP block to use +- `dropout`: the dropout rate to use in the MLP blocks +- `drop_path_rate`: Stochastic depth rate +- `activation`: the activation function to use in the MLP blocks +- `λ`: initialisation constant for the LayerScale +""" +function resmixerblock(planes, npatches; mlp_ratio = 4.0, mlp_layer = mlp_block, + dropout = 0., drop_path_rate = 0., activation = gelu, λ = 1e-4) +return Chain(SkipConnection(Chain(Flux.Diagonal(planes), + x -> permutedims(x, (2, 1, 3)), + Dense(npatches, npatches), + x -> permutedims(x, (2, 1, 3)), + LayerScale(planes, λ), + DropPath(drop_path_rate)), +), + SkipConnection(Chain(Flux.Diagonal(planes), + mlp_layer(planes, Int(mlp_ratio * planes); dropout, activation), + LayerScale(planes, λ), + DropPath(drop_path_rate)), +)) +end + +struct ResMLP + layers +end + +""" + ResMLP(size::Symbol = :base; patch_size::Int = 16, imsize::NTuple{2} = (224, 224), + drop_path_rate = 0., nclasses = 1000) + +Creates a model with the ResMLP architecture. +([reference](https://arxiv.org/abs/2105.03404)). + +# Arguments +- `size`: the size of the model - one of `small`, `base`, `large` or `huge` +- `patch_size`: the size of the patches +- `imsize`: the size of the input image +- `drop_path_rate`: Stochastic depth rate +- `nclasses`: number of output classes + +See also [`Metalhead.mlpmixer`](#). 
+""" +function ResMLP(size::Symbol = :base; patch_size::Int = 16, imsize::NTuple{2} = (224, 224), + drop_path_rate = 0., nclasses = 1000) + @assert size in keys(mixer_configs) "`size` must be one of $(keys(mixer_configs))" + patch_size = _to_tuple(patch_size) + depth = mixer_configs[size][:depth] + embedplanes = mixer_configs[size][:planes] + layers = mlpmixer(resmixerblock, imsize; mlp_ratio = 4.0, patch_size, embedplanes, + drop_path_rate, depth, nclasses) + ResMLP(layers) +end + +@functor ResMLP + +(m::ResMLP)(x) = m.layers(x) + +backbone(m::ResMLP) = m.layers[1] +classifier(m::ResMLP) = m.layers[2] + +""" + SpatialGatingUnit(norm, proj) + +Creates a spatial gating unit as described in the gMLP paper. +([reference](https://arxiv.org/abs/2105.08050)) + +# Arguments +- `norm`: the normalisation layer to use +- `proj`: the projection layer to use +""" +struct SpatialGatingUnit{T, F} + norm::T + proj::F +end + +""" + SpatialGatingUnit(planes::Int, npatches::Int; norm_layer = LayerNorm) + +Creates a spatial gating unit as described in the gMLP paper. +([reference](https://arxiv.org/abs/2105.08050)) + +# Arguments +- `planes`: the number of planes in the block +- `npatches`: the number of patches of the input +- `norm_layer`: the normalisation layer to use +""" +function SpatialGatingUnit(planes::Int, npatches::Int; norm_layer = LayerNorm) + gateplanes = planes ÷ 2 + norm = norm_layer(gateplanes) + proj = Dense(2 * eps(Float32) .* rand(Float32, npatches, npatches), ones(npatches)) + return SpatialGatingUnit(norm, proj) +end + +@functor SpatialGatingUnit + +function (m::SpatialGatingUnit)(x) + u, v = chunk(x, 2; dims = 1) + v = m.norm(v) + v = m.proj(permutedims(v, (2, 1, 3))) + return u .* permutedims(v, (2, 1, 3)) +end + +""" + spatial_gating_block(planes, npatches; mlp_ratio = 4.0, mlp_layer = gated_mlp_block, + norm_layer = LayerNorm, dropout = 0.0, drop_path_rate = 0., + activation = gelu) + +Creates a feedforward block based on the gMLP model architecture described in the paper. +([reference](https://arxiv.org/abs/2105.08050)) + +# Arguments +- `planes`: the number of planes in the block +- `npatches`: the number of patches of the input +- `mlp_ratio`: ratio of the number of hidden channels in the channel mixing MLP to the number + of planes in the block +- `norm_layer`: the normalisation layer to use +- `dropout`: the dropout rate to use in the MLP blocks +- `drop_path_rate`: Stochastic depth rate +- `activation`: the activation function to use in the MLP blocks +""" +function spatial_gating_block(planes, npatches; mlp_ratio = 4.0, norm_layer = LayerNorm, + mlp_layer = gated_mlp_block, dropout = 0., drop_path_rate = 0., + activation = gelu) + channelplanes = Int(mlp_ratio * planes) + sgu = inplanes -> SpatialGatingUnit(inplanes, npatches; norm_layer) + return SkipConnection(Chain(norm_layer(planes), + mlp_layer(sgu, planes, channelplanes; activation, dropout), + DropPath(drop_path_rate)), +) +end + +struct gMLP + layers +end + +""" + gMLP(size::Symbol = :base; patch_size::Int = 16, imsize::NTuple{2} = (224, 224), + drop_path_rate = 0., nclasses = 1000) + +Creates a model with the gMLP architecture. +([reference](https://arxiv.org/abs/2105.08050)). + +# Arguments +- `size`: the size of the model - one of `small`, `base`, `large` or `huge` +- `patch_size`: the size of the patches +- `imsize`: the size of the input image +- `drop_path_rate`: Stochastic depth rate +- `nclasses`: number of output classes + +See also [`Metalhead.mlpmixer`](#). 
+""" +function gMLP(size::Symbol = :base; patch_size::Int = 16, imsize::NTuple{2} = (224, 224), + drop_path_rate = 0., nclasses = 1000) + @assert size in keys(mixer_configs) "`size` must be one of $(keys(mixer_configs))" + patch_size = _to_tuple(patch_size) + depth = mixer_configs[size][:depth] + embedplanes = mixer_configs[size][:planes] + layers = mlpmixer(spatial_gating_block, imsize; mlp_layer = gated_mlp_block, + patch_size, embedplanes, drop_path_rate, depth, nclasses) + + gMLP(layers) +end + +@functor gMLP + +(m::gMLP)(x) = m.layers(x) + +backbone(m::gMLP) = m.layers[1] +classifier(m::gMLP) = m.layers[2] diff --git a/src/utilities.jl b/src/utilities.jl index eddb39507..232e7bf1d 100644 --- a/src/utilities.jl +++ b/src/utilities.jl @@ -1,5 +1,7 @@ +_to_tuple(x::Int) = (x, x) + # Utility function for classifier head of vision transformer-like models -_seconddimmean(x) = dropdims(mean(x, dims = 2); dims = 2) +seconddimmean(x) = dropdims(mean(x, dims = 2); dims = 2) # utility function for making sure that all layers have a channel size divisible by 8 # used by MobileNet variants diff --git a/src/vit-based/vit.jl b/src/vit-based/vit.jl index 9cf29552a..b8c8e55f3 100644 --- a/src/vit-based/vit.jl +++ b/src/vit-based/vit.jl @@ -32,7 +32,7 @@ Creates a Vision Transformer (ViT) model. - `imsize`: image size - `inchannels`: number of input channels - `patch_size`: size of the patches -- `planes`: the number of channels fed into the main model +- `embedplanes`: the number of channels after the patch embedding - `depth`: number of blocks in the transformer - `heads`: number of attention heads in the transformer - `mlpplanes`: number of hidden channels in the MLP block in the transformer @@ -42,29 +42,22 @@ Creates a Vision Transformer (ViT) model. - `pool`: pooling type, either :class or :mean - `nclasses`: number of classes in the output """ -function vit(imsize::NTuple{2} = (256, 256); inchannels = 3, patch_size = (16, 16), planes = 1024, - depth = 6, heads = 16, mlppanes = 2048, headplanes = 64, dropout = 0.1, emb_dropout = 0.1, - pool = :class, nclasses = 1000) - - im_height, im_width = imsize - patch_height, patch_width = patch_size - - @assert (im_height % patch_height == 0) && (im_width % patch_width == 0) - "Image dimensions must be divisible by the patch size." +function vit(imsize::NTuple{2} = (256, 256); inchannels = 3, patch_size = (16, 16), + embedplanes = 768, depth = 6, heads = 16, mlpplanes = 2048, headplanes = 64, + dropout = 0.1, emb_dropout = 0.1, pool = :class, nclasses = 1000) + @assert pool in [:class, :mean] "Pool type must be either :class (class token) or :mean (mean pooling)" - - npatches = (im_height ÷ patch_height) * (im_width ÷ patch_width) - patchplanes = inchannels * patch_height * patch_width - return Chain(Chain(PatchEmbedding(patch_height, patch_width), - Dense(patchplanes, planes), - ClassTokens(planes), - ViPosEmbedding(planes, npatches + 1), + npatches = prod(imsize .÷ patch_size) + + return Chain(Chain(PatchEmbedding(imsize; inchannels, patch_size, embedplanes), + ClassTokens(embedplanes), + ViPosEmbedding(embedplanes, npatches + 1), Dropout(emb_dropout), - transformer_encoder(planes, depth, heads, headplanes, mlppanes; dropout), - (pool == :class) ? x -> x[:, 1, :] : _seconddimmean), - Chain(LayerNorm(planes), Dense(planes, nclasses))) + transformer_encoder(embedplanes, depth, heads, headplanes, mlpplanes; dropout), + (pool == :class) ? 
x -> x[:, 1, :] : seconddimmean), + Chain(LayerNorm(embedplanes), Dense(embedplanes, nclasses))) end struct ViT @@ -72,9 +65,9 @@ struct ViT end """ - ViT(imsize::NTuple{2} = (256, 256); inchannels = 3, patch_size = (16, 16), planes = 1024, - depth = 6, heads = 16, mlppanes = 2048, headplanes = 64, dropout = 0.1, emb_dropout = 0.1, - pool = :class, nclasses = 1000) + ViT(imsize::NTuple{2} = (256, 256); inchannels = 3, patch_size = (16, 16), + embedplanes = 768, depth = 6, heads = 16, mlpplanes = 2048, headplanes = 64, + dropout = 0.1, emb_dropout = 0.1, pool = :class, nclasses = 1000) Creates a Vision Transformer (ViT) model. ([reference](https://arxiv.org/abs/2010.11929)). @@ -83,7 +76,7 @@ Creates a Vision Transformer (ViT) model. - `imsize`: image size - `inchannels`: number of input channels - `patch_size`: size of the patches -- `planes`: the number of channels fed into the main model +- `embedplanes`: the number of channels after the patch embedding - `depth`: number of blocks in the transformer - `heads`: number of attention heads in the transformer - `mlpplanes`: number of hidden channels in the MLP block in the transformer @@ -92,12 +85,14 @@ Creates a Vision Transformer (ViT) model. - `emb_dropout`: dropout rate for the positional embedding layer - `pool`: pooling type, either :class or :mean - `nclasses`: number of classes in the output + +See also [`Metalhead.vit`](#). """ -function ViT(imsize::NTuple{2} = (256, 256); inchannels = 3, patch_size = (16, 16), planes = 1024, - depth = 6, heads = 16, mlppanes = 2048, headplanes = 64, - dropout = 0.1, emb_dropout = 0.1, pool = :class, nclasses = 1000) +function ViT(imsize::NTuple{2} = (256, 256); inchannels = 3, patch_size = (16, 16), + embedplanes = 768, depth = 12, heads = 16, mlpplanes = 3072, headplanes = 64, + dropout = 0.1, emb_dropout = 0.1, pool = :class, nclasses = 1000) - layers = vit(imsize; inchannels, patch_size, planes, depth, heads, mlppanes, headplanes, + layers = vit(imsize; inchannels, patch_size, embedplanes, depth, heads, mlpplanes, headplanes, dropout, emb_dropout, pool, nclasses) ViT(layers) diff --git a/test/convnets.jl b/test/convnets.jl index 13dc2ad5c..f6f60fc4b 100644 --- a/test/convnets.jl +++ b/test/convnets.jl @@ -83,6 +83,8 @@ end @test_skip gradtest(m, rand(Float32, 227, 227, 3, 2)) end +GC.gc() + @testset "DenseNet" begin @testset for model in [DenseNet121, DenseNet161, DenseNet169, DenseNet201] m = model() @@ -126,6 +128,8 @@ end end end +GC.gc() + @testset "ConvNeXt" verbose = true begin @testset for mode in [:tiny, :small, :base, :large, :xlarge] @testset for drop_path_rate in [0.0, 0.5, 0.99] diff --git a/test/other.jl b/test/other.jl index 9911f2c96..2a60c2435 100644 --- a/test/other.jl +++ b/test/other.jl @@ -2,6 +2,34 @@ using Metalhead, Test using Flux @testset "MLPMixer" begin - @test size(MLPMixer()(rand(Float32, 256, 256, 3, 2))) == (1000, 2) - @test_skip gradtest(MLPMixer(), rand(Float32, 256, 256, 3, 2)) -end \ No newline at end of file + @testset for mode in [:small, :base, :large, :huge] + @testset for drop_path_rate in [0.0, 0.5, 0.99] + m = MLPMixer(mode; drop_path_rate) + @test size(m(rand(Float32, 224, 224, 3, 2))) == (1000, 2) + @test_skip gradtest(m, rand(Float32, 224, 224, 3, 2)) + GC.gc() + end + end +end + +@testset "ResMLP" begin + @testset for mode in [:small, :base, :large, :huge] + @testset for drop_path_rate in [0.0, 0.5, 0.99] + m = ResMLP(mode; drop_path_rate) + @test size(m(rand(Float32, 224, 224, 3, 2))) == (1000, 2) + @test_skip gradtest(m, rand(Float32, 224, 224, 3, 
2))
+            GC.gc()
+        end
+    end
+end
+
+@testset "gMLP" begin
+    @testset for mode in [:small, :base, :large, :huge]
+        @testset for drop_path_rate in [0.0, 0.5, 0.99]
+            m = gMLP(mode; drop_path_rate)
+            @test size(m(rand(Float32, 224, 224, 3, 2))) == (1000, 2)
+            @test_skip gradtest(m, rand(Float32, 224, 224, 3, 2))
+            GC.gc()
+        end
+    end
+end
diff --git a/test/runtests.jl b/test/runtests.jl
index 074dfc972..66ddbbaa8 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -15,11 +15,15 @@ end
     include("convnets.jl")
 end
 
+GC.gc()
+
 # Other tests
 @testset verbose = true "Other" begin
     include("other.jl")
 end
 
+GC.gc()
+
 # ViT tests
 @testset verbose = true "ViTs" begin
     include("vit-based.jl")
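As a quick sanity check of the new exports, the three MLP-style models added or reworked in this patch can be smoke-tested the same way `test/other.jl` does above. A minimal sketch, assuming this branch of Metalhead is active; the `:small` size and the stochastic-depth rate of 0.1 are arbitrary illustrative choices, not values fixed by the patch:

```julia
using Metalhead

x = rand(Float32, 224, 224, 3, 2)           # 224×224 RGB images, batch of 2, as in the tests

for Model in (MLPMixer, ResMLP, gMLP)
    m = Model(:small; drop_path_rate = 0.1) # size symbol plus stochastic depth rate
    @assert size(m(x)) == (1000, 2)         # 1000-class logits per image
end
```

Note that `MLPMixer` now takes a size symbol (`:small`, `:base`, `:large` or `:huge`) instead of the old positional `imsize` argument, so callers of the previous signature will need to update.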
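The additions to the Layers module can also be exercised in isolation. The sketch below is illustrative only: it assumes the new layers stay reachable through a `Metalhead.Layers` submodule as the updated export list suggests (the exact import path is an assumption, not shown in this diff), and the shapes follow the 16-pixel patch defaults:

```julia
using Metalhead
using Metalhead.Layers: LayerScale, DropPath, PatchEmbedding, gated_mlp_block  # path assumed

# The rewritten PatchEmbedding is a Chain: strided Conv patchify, flatten, then norm_layer.
pe = PatchEmbedding((224, 224); inchannels = 3, patch_size = (16, 16), embedplanes = 768)
x = rand(Float32, 224, 224, 3, 1)
size(pe(x))                       # (768, 196, 1): embedplanes × npatches × batch

# LayerScale(planes, λ) holds a learnable per-channel scale; λ ≤ 0 collapses it to `identity`.
ls = LayerScale(768, 1f-6)
size(ls(pe(x)))                   # shape is unchanged: (768, 196, 1)

# DropPath(p) is Dropout(p; dims = 4) for p ≥ 0 and `identity` for negative rates.
DropPath(-1.0) === identity       # true

# gated_mlp_block needs an even hidden width; an `identity` gate falls back to mlp_block.
g = gated_mlp_block(identity, 768, 3072)
```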