# Copyright 2022 Big Vision Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# pylint: disable=line-too-long
16+ r"""Distilling BiT-R152x2 into BiT-R50x1 on Flowers/Pet as in https://arxiv.org/abs/2106.05237
17+
18+ While many epochs are required, this is a small dataset, and thus overall it
19+ is still fast and possible to run on the relatively small v3-8TPUs (or GPUs).
20+
21+ This configuration contains the recommended settings from Fig3/Tab4 of the
22+ paper, which can be selected via the fast/medium/long config argument.
23+ (best settings were selected on a 10% minival)
24+
For Flowers:
- The `fast` variant takes ~1h10m on a v2-8 TPU.
  Example logs at gs://big_vision/distill/bit_flowers_fast_06-18_2008/big_vision_metrics.txt
- The `long` variant takes ~25h on a v3-32 TPU.
  Example logs at gs://big_vision/distill/bit_flowers_long_06-19_0524/big_vision_metrics.txt
For Pet:
- The `fast` variant takes ~28min on a v2-8 TPU.
  Example logs at gs://big_vision/distill/bit_pet_fast_06-16_2338/big_vision_metrics.txt
- The `long` variant takes ~11h on a v2-8 and ~8h on a v3-32.
  Example logs at gs://big_vision/distill/bit_pet_long_06-17_0050/big_vision_metrics.txt

big_vision.trainers.proj.distill.distill \
    --config big_vision/configs/proj/distill/bigsweep_flowers_pet.py:data=flowers,variant=fast \
    --workdir gs://[your_bucket]/big_vision/`date '+%m-%d_%H%M'` \
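
To build the same config object in Python (e.g. for quick inspection), the
argument string after the colon can be passed to `get_config` directly:

  from big_vision.configs.proj.distill import bigsweep_flowers_pet
  config = bigsweep_flowers_pet.get_config('data=pet,variant=fast')
  print(config.num_epochs)  # -> 1000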
39+ """
40+
41+ import big_vision .configs .common as bvcc
42+ import big_vision .configs .proj .distill .common as cd
43+ import ml_collections as mlc
44+
45+ NCLS = dict (flowers = 102 , pet = 37 )
46+
47+
def get_config(arg=None):
  """Config for distilling BiT-R152x2 into BiT-R50x1 on Flowers/Pet."""
  arg = bvcc.parse_arg(arg, runlocal=False, data='flowers', variant='medium', crop='inception_crop(128)')
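  # `arg` is the string after the colon in the --config flag, e.g.
  # 'data=pet,variant=long' sets arg.data == 'pet' and arg.variant == 'long';
  # fields not given keep the defaults above.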
  config = mlc.ConfigDict()

  config.dataset = dict(flowers='oxford_flowers102', pet='oxford_iiit_pet')[arg.data]
  config.cache_raw = True
  config.prefetch_to_device = 4
  config.train_split = dict(flowers='train', pet='train[:90%]')[arg.data]
  config.num_classes = NCLS[arg.data]

  config.batch_size = 512
  config.num_epochs = {
      'flowers': {'fast': 10_000, 'medium': 100_000, 'long': 1_000_000},
      'pet': {'fast': 1000, 'medium': 3000, 'long': 30_000},
  }[arg.data][arg.variant]
  config.shuffle_buffer_size = 50_000
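  # Rough step count: steps = num_epochs * train_size / batch_size. Assuming
  # the oxford_flowers102 'train' split has 1020 examples, the flowers 'fast'
  # variant is about 10_000 * 1020 / 512 ~= 20k steps, and 'long' ~2M steps.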

  config.log_training_steps = 100
  config.checkpoint_steps = 2500

  # Model section
  config.student_name = 'bit_paper'
  config.student = dict(depth=50, width=1)

  config.teachers = ['prof_m']
  config.prof_m_name = 'bit_paper'
  config.prof_m_init = cd.inits[f'BiT-M R152x2 {arg.data} rc128']
  config.prof_m = dict(depth=152, width=2)

  # Preprocessing pipeline for student & teacher.
  pp_common = (
      '|value_range(-1, 1)'
      f'|onehot({config.num_classes}, key="label", key_result="labels")'
      '|keep("image", "labels")'
  )
  config.pp_train = f'decode|{arg.crop}|flip_lr' + pp_common
  ppv = 'decode|resize_small(160)|central_crop(128)' + pp_common
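  # For the flowers defaults, pp_train above expands to:
  #   'decode|inception_crop(128)|flip_lr|value_range(-1, 1)'
  #   '|onehot(102, key="label", key_result="labels")|keep("image", "labels")'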

  config.mixup = dict(p=1.0, n=2)
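  # Mixup as in the paper's "function matching" recipe. Assumption: n=2 mixes
  # pairs of examples and p=1.0 is the Beta concentration, i.e. the mixing
  # weight is drawn uniformly from [0, 1].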

  # Distillation settings
  config.distance = 'kl'
  config.distance_kw = dict(t={
      'flowers': {'fast': 10., 'medium': 1., 'long': 1.},
      'pet': {'fast': 5., 'medium': 10., 'long': 2.},
  }[arg.data][arg.variant])
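  # Assumption: `t` is the usual distillation softmax temperature, i.e. both
  # teacher and student logits are scaled by 1/t inside the KL loss.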

  # Optimizer section
  config.grad_clip_norm = 1.0
  config.optax_name = 'scale_by_adam'
  config.optax = dict(mu_dtype='bfloat16')

  config.lr = {
      'flowers': {'fast': 0.003, 'medium': 0.001, 'long': 0.0003},
      'pet': {'fast': 0.01, 'medium': 0.003, 'long': 0.003},
  }[arg.data][arg.variant]
  config.wd = {
      'flowers': {'fast': 3e-4, 'medium': 1e-4, 'long': 1e-5},
      'pet': {'fast': 1e-3, 'medium': 3e-4, 'long': 1e-5},
  }[arg.data][arg.variant]
  config.schedule = dict(warmup_steps=1500, decay_type='cosine')
  config.optim_name = 'adam_hp'
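  # Adam keeps its first-moment state in bfloat16 (mu_dtype) to save optimizer
  # memory; the learning rate warms up for 1500 steps and then follows a
  # cosine decay over the remaining training steps.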

  # Eval section
  minitrain_split = 'train[:512]' if not arg.runlocal else 'train[:16]'
  if arg.data == 'flowers':
    val_split = 'validation' if not arg.runlocal else 'validation[:16]'
    test_split = 'test' if not arg.runlocal else 'test[:16]'
  elif arg.data == 'pet':
    val_split = 'train[90%:]' if not arg.runlocal else 'train[:16]'
    test_split = 'test' if not arg.runlocal else 'test[:16]'

  base = dict(
      type='classification',
      pred='student_fwd',
      dataset=config.dataset,
      pp_fn=ppv,
      loss_name='softmax_xent',
      log_steps=500,
  )
  config.evals = {}
  config.evals.student_train = {**base, 'split': minitrain_split}
  config.evals.student_val = {**base, 'split': val_split}
  config.evals.student_test = {**base, 'split': test_split}

  # Teacher is fixed, so rare evals.
  teacher = dict(log_steps=100_000, pred='prof_m_fwd')
  config.evals.teacher_train = {**config.evals.student_train, **teacher}
  config.evals.teacher_val = {**config.evals.student_val, **teacher}
  config.evals.teacher_test = {**config.evals.student_test, **teacher}

  # Could in principle also look at agreement on other datasets!
  dist = dict(
      type='proj.distill.distance',
      pred='student_prof_m_fwd',
      dataset=config.dataset,
      pp_fn=ppv + '|keep("image")',
      log_steps=1000,
      distances=({'kind': 'kl'}, {'kind': 'euclidean'},
                 {'kind': 'agree', 'k': 1}, {'kind': 'agree', 'k': 5}),
  )
  config.evals.dist_train = {**dist, 'split': minitrain_split}
  config.evals.dist_val = {**dist, 'split': val_split}
  config.evals.dist_test = {**dist, 'split': test_split}
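  # These evals track how closely the student matches the teacher: KL and
  # euclidean distance between the two output distributions, plus top-1 and
  # top-5 prediction agreement ('agree' with k=1/5).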

  # Make a few things much smaller for quick local debugging test runs.
  if arg.runlocal:
    config.shuffle_buffer_size = 10
    config.batch_size = 8

  return config