This repository contains our efforts to replicate the fine-tuning and training experiments of the MLP-Mixer architecture from the NeurIPS 2021 paper:
MLP-Mixer: An all-MLP Architecture for Vision
Tolstikhin et al., NeurIPS 2021
https://proceedings.neurips.cc/paper/2021/hash/cba0a4ee5ccd02fda0fe3f9a3e7b89fe-Abstract.html
We forked the original code from Google Research’s vision_transformer and added configuration files to run fine-tuning and from-scratch training experiments on CIFAR-10, CIFAR-100, and ImageNet-1K.
```
├── .github/       ← CI workflows
├── vit_jax/       ← JAX implementation
│   ├── configs/   ← Experiment configs
│   └── …          ← Requirements, source code, data loaders, training scripts
├── .gitignore
├── README.md
└── LICENSE
```
We provide config files for the different experiments; they all live under vit_jax/configs/.

Fine-tuning from pretrained checkpoints:
| Pretraining dataset | Fine-tuning dataset | Model | Config file |
|---|---|---|---|
| ImageNet-1k | CIFAR-10 | Mixer-B/16 | mixer_base16_cifar10_imagenet1k.py |
| ImageNet-1k | CIFAR-10 | Mixer-L/16 | mixer_large16_cifar10_imagenet1k.py |
| ImageNet-21k | CIFAR-10 | Mixer-B/16 | mixer_base16_cifar10_imagenet21k.py |
| ImageNet-21k | CIFAR-10 | Mixer-L/16 | mixer_large16_cifar10_imagenet21k.py |
| ImageNet-21k | CIFAR-100 | Mixer-B/16 | mixer_base16_cifar100.py |
| ImageNet-21k | ImageNet-1K | Mixer-B/16 | mixer_base16_imagenet1k_imagenet21k.py |
| ImageNet-21k | ImageNet-1K | Mixer-L/16 | mixer_large16_imagenet1k_imagenet21k.py |
Training from scratch (no pretraining):

| Dataset | Model | Config file |
|---|---|---|
| CIFAR-10 | Mixer-B/16 | mixer_base16_cifar10_nopretrain.py |
| CIFAR-10 | Mixer-L/16 | mixer_large16_cifar10_nopretrain.py |
| ImageNet-1K | Mixer-B/16 | mixer_base16_imagenet1k-nopretrain.py |
| ImageNet-1K | Mixer-L/16 | mixer_large16_imagenet1k-nopretrain.py |
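For reference, each config is a plain Python file exposing a `get_config()` function, following the ml_collections pattern used in vit_jax. The sketch below shows the general shape; the field names mirror the `--config.*` flags used in the commands further down, but the values are placeholders, not the settings shipped in our configs.

```python
# Minimal sketch of a config in vit_jax/configs/ (illustrative values only).
import ml_collections


def get_config() -> ml_collections.ConfigDict:
    config = ml_collections.ConfigDict()
    config.dataset = 'cifar10'        # TFDS name or a local data directory
    config.pretrained_dir = ''        # overridden via --config.pretrained_dir
    config.batch = 512                # global training batch size
    config.batch_eval = 512           # global evaluation batch size
    config.accum_steps = 8            # gradient-accumulation steps
    config.base_lr = 0.003            # peak learning rate
    config.total_steps = 10_000       # number of optimizer steps
    return config
```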
Requirements:

- Hardware: 3× NVIDIA RTX 4090 (or equivalent)
- OS: Linux
- Python: 3.8–3.10
- Dependencies:

```bash
pip install -r vit_jax/requirements.txt
pip install tensorflow-datasets jax jaxlib  # for data loading
```
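Before launching anything, it is worth confirming that JAX actually sees all three GPUs:

```python
# Quick check that JAX detects the expected GPU devices.
import jax

print(jax.devices())             # expect three GPU device entries
print(jax.local_device_count())  # expect 3
```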
Dataset preparation:

- CIFAR-10 & CIFAR-100:

```bash
python3 - <<'EOF'
import os
import tensorflow_datasets as tfds

# Download and prepare both CIFAR datasets under ~/tensorflow_datasets.
# (A literal '$HOME' inside the Python string would not expand, so we
# resolve the home directory explicitly.)
data_dir = os.path.expanduser('~/tensorflow_datasets')
tfds.load('cifar10', data_dir=data_dir, download=True)
tfds.load('cifar100', data_dir=data_dir, download=True)
EOF
```
- ImageNet-1K:

```bash
wget --auth-no-challenge \
  --user=<USERNAME> --password=<PASSWORD> \
  -P /devkit \
  https://www.image-net.org/challenges/LSVRC/2012/nnoupb/ILSVRC2012_devkit_t12.tar.gz
python /home/$USER/tensorflow_datasets/imagenet/unzip.py \
  --data_root=$HOME/tensorflow_datasets/imagenet \
  --devkit_path=/devkit/ILSVRC2012_devkit_t12.tar.gz
```
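As a quick sanity check after preparation (assuming the default ~/tensorflow_datasets location used above), you can verify that the CIFAR builds succeeded:

```python
# Sanity-check the prepared CIFAR datasets via their TFDS metadata.
import tensorflow_datasets as tfds

for name in ('cifar10', 'cifar100'):
    info = tfds.builder(name).info
    print(name,
          info.splits['train'].num_examples,  # expect 50000
          info.splits['test'].num_examples)   # expect 10000
```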
Replace $(pwd) with your project root if needed.
Fine-tuning on CIFAR-10:

```bash
# Dataset: CIFAR-10, pretrain: ImageNet-21k, model: Mixer-B/16
CUDA_VISIBLE_DEVICES=0,1,2 python3 -m vit_jax.main \
  --workdir="$(pwd)/logs/mixer_cifar10_ft/mixer_b16_cifar10_ft_3gpu_imagenet21k" \
  --config="$(pwd)/vit_jax/configs/mixer_base16_cifar10.py" \
  --config.pretrained_dir="gs://mixer_models/imagenet21k" \
  --config.batch=510 \
  --config.batch_eval=510 \
  --config.accum_steps=5 \
  > $(pwd)/logs/mixer_cifar10_ft/mixer_b16_cifar10_ft_3gpu_imagenet21k/mylog.log 2>&1 & disown
```
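Note that `--config.batch` must divide evenly across the visible GPUs, which is why we use 510 rather than 512 on 3 GPUs. Below is a hypothetical pre-flight check, under the assumption that the global batch is sharded across local devices and then split across accumulation steps:

```python
# Hypothetical helper (our assumption about sharding, not code from vit_jax):
# verify that a global batch divides across devices and accumulation steps.
import jax


def check_batch(batch: int, accum_steps: int) -> None:
    n_dev = jax.local_device_count()
    assert batch % n_dev == 0, f'batch {batch} not divisible by {n_dev} devices'
    per_device = batch // n_dev
    assert per_device % accum_steps == 0, (
        f'per-device batch {per_device} not divisible by {accum_steps} accum steps')
    print(f'{n_dev} devices x {per_device // accum_steps} examples x {accum_steps} accum steps')


check_batch(510, 5)  # on 3 GPUs: 170 per device, 34 per micro-step
```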
OR

```bash
# Dataset: CIFAR-10, pretrain: ImageNet-21k, model: Mixer-B/16, local checkpoint
CUDA_VISIBLE_DEVICES=0,1,2 python3 -m vit_jax.main \
  --workdir="$(pwd)/logs/mixer_cifar10_ft_btsize510/mixer_b16_cifar10_ft_3gpu_imagenet21k" \
  --config="$(pwd)/vit_jax/configs/mixer_base16_cifar10_imagenet21k.py" \
  --config.pretrained_dir="$(pwd)/pretrained" \
  --config.batch=120 \
  --config.batch_eval=120 \
  --config.accum_steps=4 \
  > $(pwd)/logs/mixer_cifar10_ft_btsize510/mixer_b16_cifar10_ft_3gpu_imagenet21k/mylog.log 2>&1 & disown
```

```bash
# Dataset: CIFAR-10, pretrain: ImageNet-21k, model: Mixer-L/16
CUDA_VISIBLE_DEVICES=0,1,2 python3 -m vit_jax.main \
  --workdir="$(pwd)/logs/mixer_cifar10_ft_btsize510/mixer_l16_cifar10_ft_3gpu_imagenet21k" \
  --config="$(pwd)/vit_jax/configs/mixer_large16_cifar10_imagenet21k.py" \
  --config.pretrained_dir="gs://mixer_models/imagenet21k" \
  --config.batch=120 \
  --config.batch_eval=120 \
  --config.accum_steps=4 \
  > $(pwd)/logs/mixer_cifar10_ft_btsize510/mixer_l16_cifar10_ft_3gpu_imagenet21k/mylog.log 2>&1 & disown
```

OR
```bash
# Dataset: CIFAR-10, pretrain: ImageNet-21k, model: Mixer-L/16, local checkpoint
CUDA_VISIBLE_DEVICES=0,1,2 python3 -m vit_jax.main \
  --workdir="$(pwd)/logs/mixer_cifar10_ft_btsize510/mixer_l16_cifar10_ft_3gpu_imagenet21k" \
  --config="$(pwd)/vit_jax/configs/mixer_large16_cifar10_imagenet21k.py" \
  --config.pretrained_dir="$(pwd)/pretrained" \
  --config.batch=120 \
  --config.batch_eval=120 \
  --config.accum_steps=4 \
  > $(pwd)/logs/mixer_cifar10_ft_btsize510/mixer_l16_cifar10_ft_3gpu_imagenet21k/mylog.log 2>&1 & disown
```

```bash
# Dataset: CIFAR-10, pretrain: ImageNet-1k, model: Mixer-B/16
CUDA_VISIBLE_DEVICES=0,1,2 python3 -m vit_jax.main \
  --workdir="$(pwd)/logs/mixer_cifar10_ft/mixer_b16_cifar10_ft_3gpu_up_imagenet" \
  --config="$(pwd)/vit_jax/configs/mixer_base16_cifar10.py" \
  --config.pretrained_dir="gs://mixer_models/imagenet1k" \
  --config.batch=510 \
  --config.batch_eval=510 \
  --config.accum_steps=5 \
  > $(pwd)/logs/mixer_cifar10_ft/mixer_b16_cifar10_ft_3gpu_up_imagenet/mylog.log 2>&1 & disown
```

OR
```bash
# Dataset: CIFAR-10, pretrain: ImageNet-1k, model: Mixer-B/16, local checkpoint
CUDA_VISIBLE_DEVICES=0,1,2 python3 -m vit_jax.main \
  --workdir="$(pwd)/logs/mixer_cifar10_ft_btsize510/mixer_b16_cifar10_ft_3gpu_imagenet" \
  --config="$(pwd)/vit_jax/configs/mixer_base16_cifar10_imagenet1k.py" \
  --config.pretrained_dir="$(pwd)/pretrained" \
  --config.batch=120 \
  --config.batch_eval=120 \
  --config.accum_steps=4 \
  > $(pwd)/logs/mixer_cifar10_ft_btsize510/mixer_b16_cifar10_ft_3gpu_imagenet/mylog.log 2>&1 & disown
```

```bash
# Dataset: CIFAR-10, pretrain: ImageNet-1k, model: Mixer-L/16
CUDA_VISIBLE_DEVICES=0,1,2 python3 -m vit_jax.main \
  --workdir="$(pwd)/logs/mixer_cifar10_ft/mixer_l16_cifar10_ft_3gpu_up_imagenet" \
  --config="$(pwd)/vit_jax/configs/mixer_large16_cifar10.py" \
  --config.pretrained_dir="gs://mixer_models/imagenet1k" \
  --config.batch=510 \
  --config.batch_eval=510 \
  --config.accum_steps=5 \
  > $(pwd)/logs/mixer_cifar10_ft/mixer_l16_cifar10_ft_3gpu_up_imagenet/mylog.log 2>&1 & disown
```

OR
```bash
# Dataset: CIFAR-10, pretrain: ImageNet-1k, model: Mixer-L/16, local checkpoint
CUDA_VISIBLE_DEVICES=0,1,2 python3 -m vit_jax.main \
  --workdir="$(pwd)/logs/mixer_cifar10_ft_btsize510/mixer_l16_cifar10_ft_3gpu_up_imagenet" \
  --config="$(pwd)/vit_jax/configs/mixer_large16_cifar10_imagenet1k.py" \
  --config.pretrained_dir="$(pwd)/pretrained" \
  --config.batch=120 \
  --config.batch_eval=120 \
  --config.accum_steps=4 \
  > $(pwd)/logs/mixer_cifar10_ft_btsize510/mixer_l16_cifar10_ft_3gpu_up_imagenet/mylog.log 2>&1 & disown
```

Note: the paper recommends base_lr = 0.003 for ImageNet-1K fine-tuning; we observed some instability (accuracy dips to around 61%), likely due to randomness in initialization.
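For intuition about how base_lr and total_steps interact, schedules in vit_jax-style setups typically warm up to base_lr and then decay it over total_steps. The sketch below illustrates one common shape (linear warmup into cosine decay); the exact shape and warmup length are assumptions for illustration, not the repo's implementation.

```python
# Illustrative warmup + cosine-decay schedule driven by base_lr and
# total_steps (an assumed shape for intuition, not vit_jax's exact code).
import math


def lr_at(step: int, base_lr: float = 0.003,
          total_steps: int = 10_000, warmup_steps: int = 500) -> float:
    if step < warmup_steps:
        return base_lr * step / warmup_steps  # linear warmup to base_lr
    progress = (step - warmup_steps) / (total_steps - warmup_steps)
    return base_lr * 0.5 * (1.0 + math.cos(math.pi * progress))  # cosine decay


print(lr_at(250), lr_at(500), lr_at(10_000))  # 0.0015 0.003 ~0.0
```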
Fine-tuning on ImageNet-1K:

```bash
# Dataset: ImageNet-1k, pretrain: ImageNet-21k, model: Mixer-B/16
CUDA_VISIBLE_DEVICES=0,1,2 python3 -m vit_jax.main \
  --workdir="$(pwd)/logs/mixer_imagenet_ft/mixer_b16_imagenet_ft_3gpu_imagenet21k" \
  --config="$(pwd)/vit_jax/configs/mixer_base16_imagenet1k.py" \
  --config.dataset="$HOME/tensorflow_datasets/imagenet" \
  --config.pretrained_dir="gs://mixer_models/imagenet21k" \
  --config.batch=510 \
  --config.batch_eval=510 \
  --config.accum_steps=5 \
  > $(pwd)/logs/mixer_imagenet_ft/mixer_b16_imagenet_ft_3gpu_imagenet21k/mylog.log 2>&1 & disown
```

OR
```bash
# Dataset: ImageNet-1k, pretrain: ImageNet-21k, model: Mixer-B/16, local checkpoint
CUDA_VISIBLE_DEVICES=0,1,2 python3 -m vit_jax.main \
  --workdir="$(pwd)/logs/mixer_imagenet_ft_btsize510/mixer_b16_imagenet_ft_3gpu_imagenet21k" \
  --config="$(pwd)/vit_jax/configs/mixer_base16_imagenet1k_imagenet21k.py" \
  --config.dataset="$HOME/tensorflow_datasets/imagenet" \
  --config.pretrained_dir="$(pwd)/pretrained" \
  --config.batch=120 \
  --config.batch_eval=120 \
  --config.accum_steps=4 \
  > $(pwd)/logs/mixer_imagenet_ft_btsize510/mixer_b16_imagenet_ft_3gpu_imagenet21k/mylog.log 2>&1 & disown
```

```bash
# Dataset: ImageNet-1k, pretrain: ImageNet-21k, model: Mixer-L/16
CUDA_VISIBLE_DEVICES=0,1,2 python3 -m vit_jax.main \
  --workdir="$(pwd)/logs/mixer_imagenet_ft/mixer_l16_imagenet_ft_3gpu_imagenet21k" \
  --config="$(pwd)/vit_jax/configs/mixer_large16_imagenet1k.py" \
  --config.dataset="$HOME/tensorflow_datasets/imagenet" \
  --config.pretrained_dir="gs://mixer_models/imagenet21k" \
  --config.batch=510 \
  --config.batch_eval=510 \
  --config.accum_steps=5 \
  > $(pwd)/logs/mixer_imagenet_ft/mixer_l16_imagenet_ft_3gpu_imagenet21k/mylog.log 2>&1 & disown
```

OR
```bash
# Dataset: ImageNet-1k, pretrain: ImageNet-21k, model: Mixer-L/16, local checkpoint
CUDA_VISIBLE_DEVICES=0,1,2 python3 -m vit_jax.main \
  --workdir="$(pwd)/logs/mixer_imagenet_ft_btsize510/mixer_l16_imagenet_ft_3gpu_imagenet21k" \
  --config="$(pwd)/vit_jax/configs/mixer_large16_imagenet1k_imagenet21k.py" \
  --config.dataset="$HOME/tensorflow_datasets/imagenet" \
  --config.pretrained_dir="$(pwd)/pretrained" \
  --config.batch=120 \
  --config.batch_eval=120 \
  --config.accum_steps=4 \
  > $(pwd)/logs/mixer_imagenet_ft_btsize510/mixer_l16_imagenet_ft_3gpu_imagenet21k/mylog.log 2>&1 & disown
```

Training from scratch (no pretrained checkpoint):

```bash
# Dataset: CIFAR-10, no pretraining, model: Mixer-B/16
CUDA_VISIBLE_DEVICES=0,1,2 python3 -m vit_jax.main \
  --workdir="$(pwd)/logs/mixer_b16_cifar10_training" \
  --config="$(pwd)/vit_jax/configs/mixer_base16_cifar10_nopretrain.py" \
  --config.batch=510 \
  --config.batch_eval=510 \
  > $(pwd)/logs/mixer_b16_cifar10_training/mylog.log 2>&1 & disown
```

```bash
# Dataset: CIFAR-10, no pretraining, model: Mixer-L/16
CUDA_VISIBLE_DEVICES=0,1,2 python3 -m vit_jax.main \
  --workdir="$(pwd)/logs/mixer_l16_cifar10_training" \
  --config="$(pwd)/vit_jax/configs/mixer_large16_cifar10_nopretrain.py" \
  --config.batch=510 \
  --config.batch_eval=510 \
  > $(pwd)/logs/mixer_l16_cifar10_training/mylog.log 2>&1 & disown
```

```bash
# Dataset: ImageNet-1k, no pretraining, model: Mixer-B/16
CUDA_VISIBLE_DEVICES=0,1,2 python3 -m vit_jax.main \
  --workdir="$(pwd)/logs/mixer_imagenet_training/mixer_b16_imagenet" \
  --config="$(pwd)/vit_jax/configs/mixer_base16_imagenet1k-nopretrain.py" \
  --config.dataset="$HOME/tensorflow_datasets/imagenet" \
  --config.batch=510 \
  --config.batch_eval=510 \
  > $(pwd)/logs/mixer_imagenet_training/mixer_b16_imagenet/mylog.log 2>&1 & disown
```

```bash
# Dataset: ImageNet-1k, no pretraining, model: Mixer-L/16
CUDA_VISIBLE_DEVICES=0,1,2 python3 -m vit_jax.main \
  --workdir="$(pwd)/logs/mixer_imagenet_training/mixer_l16_imagenet" \
  --config="$(pwd)/vit_jax/configs/mixer_large16_imagenet1k-nopretrain.py" \
  --config.dataset="$HOME/tensorflow_datasets/imagenet" \
  --config.batch=510 \
  --config.batch_eval=510 \
  > $(pwd)/logs/mixer_imagenet_training/mixer_l16_imagenet/mylog.log 2>&1 & disown
```

Results

CIFAR-10 fine-tuning, batch size 512, total_steps = 1000:
| upstream | model | dataset | top-1 accuracy |
|---|---|---|---|
| ImageNet-1k | Mixer-B/16 | cifar10 | 96.72% |
| ImageNet-1k | Mixer-L/16 | cifar10 | 96.59% |
| ImageNet-21k | Mixer-B/16 | cifar10 | 96.82% |
| ImageNet-21k | Mixer-L/16 | cifar10 | 98.34% |
Batch size 510 × 5 = 2550, total_steps = 10000:
| upstream | model | dataset | top-1 accuracy |
|---|---|---|---|
| ImageNet-21k | Mixer-L/16 | cifar10 | 98.29% |
We ran the experiment above because the 1:1 comparison below yields lower accuracy than reported in the paper; we find that a larger total_steps budget gives a modest improvement in test accuracy.
Batch size 120 × 4 = 480, total_steps = 1000:
| upstream | model | dataset | top-1 accuracy |
|---|---|---|---|
| ImageNet-1k | Mixer-B/16 | cifar10 | 94.78% |
| ImageNet-1k | Mixer-L/16 | cifar10 | 95.79% |
| ImageNet-21k | Mixer-B/16 | cifar10 | 96.44% |
| ImageNet-21k | Mixer-L/16 | cifar10 | 98.18% |
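For intuition, the (batch size, total_steps) settings above imply very different numbers of passes over CIFAR-10's 50,000 training images, which helps explain the accuracy gap between the two settings:

```python
# Convert a (batch, total_steps) setting into epochs over CIFAR-10 (50k images).
def epochs(batch: int, total_steps: int, train_size: int = 50_000) -> float:
    return batch * total_steps / train_size


print(epochs(2550, 10_000))  # 510.0 epochs for the 510*5, 10k-step setting
print(epochs(480, 1_000))    # 9.6 epochs for the 120*4, 1k-step setting
```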
ImageNet-1K fine-tuning, batch size 512, total_steps = 125000:
| upstream | model | dataset | top-1 accuracy |
|---|---|---|---|
| ImageNet-21k | Mixer-B/16 | ImageNet-1k | 80.64% |
| ImageNet-21k | Mixer-L/16 | ImageNet-1k | 82.89% |
Batch size 510 × 5 = 2550, total_steps = 125000:
| upstream | model | dataset | top-1 accuracy |
|---|---|---|---|
| ImageNet-21k | Mixer-B/16 | ImageNet-1k | 78.47% |
Batch size 510 × 5 = 2550, total_steps = 35000:
| upstream | model | dataset | top-1 accuracy |
|---|---|---|---|
| ImageNet-21k | Mixer-L/16 | ImageNet-1k | 81.84% |
We ran the two experiments above because the 1:1 comparison below yields lower accuracy than reported in the paper; a larger total_steps budget gives a modest improvement in test accuracy, though still below the paper's numbers.
Batch size 120 × 4 = 480, total_steps = 50000:
| upstream | model | dataset | top-1 accuracy |
|---|---|---|---|
| ImageNet-21k | Mixer-B/16 | ImageNet-1k | 77.57% |
| ImageNet-21k | Mixer-L/16 | ImageNet-1k | 81.44% |
This code derives from Google Research’s vision_transformer (Apache 2.0). See LICENSE for details.