|
1 | 1 | """ |
2 | | - mlp_block(planes, hidden_planes; dropout = 0., dense = Dense, activation = gelu) |
| 2 | + mlp_block(inplanes::Integer, hidden_planes::Integer, outplanes::Integer = inplanes; |
| 3 | + dropout = 0., activation = gelu) |
3 | 4 |
|
4 | | -Feedforward block used in many vision transformer-like models. |
| 5 | +Feedforward block used in many MLPMixer-like and vision-transformer models. |
5 | 6 |
|
6 | 7 | # Arguments |
7 | | -- `planes`: Number of dimensions in the input and output. |
| 8 | +- `inplanes`: Number of dimensions in the input. |
8 | 9 | - `hidden_planes`: Number of dimensions in the intermediate layer. |
| 10 | +- `outplanes`: Number of dimensions in the output - by default it is the same as `inplanes`. |
9 | 11 | - `dropout`: Dropout rate. |
10 | | -- `dense`: Type of dense layer to use in the feedforward block. |
11 | 12 | - `activation`: Activation function to use. |
12 | 13 | """ |
13 | | -function mlp_block(planes, hidden_planes; dropout = 0., dense = Dense, activation = gelu) |
14 | | - Chain(dense(planes, hidden_planes, activation), Dropout(dropout), |
15 | | - dense(hidden_planes, planes), Dropout(dropout)) |
| 14 | +function mlp_block(inplanes::Integer, hidden_planes::Integer, outplanes::Integer = inplanes; |
| 15 | + dropout = 0., activation = gelu) |
| 16 | + Chain(Dense(inplanes, hidden_planes, activation), Dropout(dropout), |
| 17 | + Dense(hidden_planes, outplanes), Dropout(dropout)) |
16 | 18 | end |
| 19 | + |
| 20 | +""" |
| 21 | + gated_mlp_block(gate_layer, inplanes::Integer, hidden_planes::Integer, |
| 22 | + outplanes::Integer = inplanes; dropout = 0., activation = gelu) |
| 23 | +
|
| 24 | +Feedforward block based on the implementation in the paper "Pay Attention to MLPs". |
| 25 | +([reference](https://arxiv.org/abs/2105.08050)) |
| 26 | +
|
| 27 | +# Arguments |
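| 28 | +- `gate_layer`: Layer constructor to use for the gating, called as `gate_layer(hidden_planes)`. Passing `identity` builds a plain `mlp_block` instead. |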
| 29 | +- `inplanes`: Number of dimensions in the input. |
| 30 | +- `hidden_planes`: Number of dimensions in the intermediate layer. Must be even, since the layer after the gate takes `hidden_planes ÷ 2` inputs. |
| 31 | +- `outplanes`: Number of dimensions in the output - by default it is the same as `inplanes`. |
| 32 | +- `dropout`: Dropout rate. |
| 33 | +- `activation`: Activation function to use. |
| 34 | +""" |
| 35 | +function gated_mlp_block(gate_layer, inplanes::Integer, hidden_planes::Integer, |
| 36 | + outplanes::Integer = inplanes; dropout = 0., activation = gelu) |
| 37 | + @assert hidden_planes % 2 == 0 "`hidden_planes` must be even for gated MLP" |
| 38 | + return Chain(Dense(inplanes, hidden_planes, activation), |
| 39 | + Dropout(dropout), |
| 40 | + gate_layer(hidden_planes), |
| 41 | + Dense(hidden_planes ÷ 2, outplanes), |
| 42 | + Dropout(dropout)) |
| 43 | +end |
| 44 | +gated_mlp_block(::typeof(identity), args...; kwargs...) = mlp_block(args...; kwargs...) |
0 commit comments