Commit 8921d51

Updates input pipelines & adds LiT-B16B_2 config. (#15)
1 parent: 1c6f5aa


44 files changed: 1196 additions, 1017 deletions

big_vision/configs/bit_i1k.py

Lines changed: 29 additions & 27 deletions

@@ -31,20 +31,23 @@ def get_config(runlocal=False):
   """Config for training on ImageNet-1k."""
   config = mlc.ConfigDict()

-  config.dataset = 'imagenet2012'
-  config.train_split = 'train[:99%]'
-  config.cache_raw = not runlocal  # Needs up to 120GB of RAM!
-  config.shuffle_buffer_size = 250_000 if not runlocal else 10_000  # Per host.
+  config.seed = 0
+  config.total_epochs = 90
   config.num_classes = 1000
   config.loss = 'softmax_xent'

-  config.seed = 0
-  config.batch_size = 4096 if not runlocal else 32
-  config.total_epochs = 90
+  config.input = dict()
+  config.input.data = dict(
+      name='imagenet2012',
+      split='train[:99%]',
+  )
+  config.input.batch_size = 4096 if not runlocal else 32
+  config.input.cache_raw = not runlocal  # Needs up to 120GB of RAM!
+  config.input.shuffle_buffer_size = 250_000 if not runlocal else 10_000  # Per host.

   pp_common = '|onehot(1000, key="{lbl}", key_result="labels")'
   pp_common += '|value_range(-1, 1)|keep("image", "labels")'
-  config.pp_train = 'decode_jpeg_and_inception_crop(224)|flip_lr' + pp_common.format(lbl='label')
+  config.input.pp = 'decode_jpeg_and_inception_crop(224)|flip_lr' + pp_common.format(lbl='label')
   pp_eval = 'decode|resize_small(256)|central_crop(224)' + pp_common

   config.log_training_steps = 50
@@ -62,30 +65,29 @@ def get_config(runlocal=False):
   config.grad_clip_norm = 1.0

   # linear scaling rule. Don't forget to sweep if sweeping batch_size.
-  config.wd = (1e-4 / 256) * config.batch_size
-  config.lr = (0.1 / 256) * config.batch_size
+  config.wd = (1e-4 / 256) * config.input.batch_size
+  config.lr = (0.1 / 256) * config.input.batch_size
   config.schedule = dict(decay_type='cosine', warmup_steps=1000)

   # Eval section
-  eval_common = dict(
-      type='classification',
-      dataset='imagenet2012',
-      pp_fn=pp_eval.format(lbl='label'),
-      loss_name=config.loss,
-      log_steps=1000,  # Very fast O(seconds) so it's fine to run it often.
-  )
+  def get_eval(split, dataset='imagenet2012'):
+    return dict(
+        type='classification',
+        data=dict(name=dataset, split=split),
+        pp_fn=pp_eval.format(lbl='label'),
+        loss_name=config.loss,
+        log_steps=1000,  # Very fast O(seconds) so it's fine to run it often.
+        cache_final=not runlocal,
+    )
   config.evals = {}
-  config.evals.train = {**eval_common, 'split': 'train[:2%]'}
-  config.evals.minival = {**eval_common, 'split': 'train[99%:]'}
-  config.evals.val = {**eval_common, 'split': 'validation'}
-  config.evals.v2 = {**eval_common, 'dataset': 'imagenet_v2', 'split': 'test'}
-
-  config.evals.real = dict(**eval_common)
-  config.evals.real.dataset = 'imagenet2012_real'
-  config.evals.real.split = 'validation'
+  config.evals.train = get_eval('train[:2%]')
+  config.evals.minival = get_eval('train[99%:]')
+  config.evals.val = get_eval('validation')
+  config.evals.v2 = get_eval('test', dataset='imagenet_v2')
+  config.evals.real = get_eval('validation', dataset='imagenet2012_real')
   config.evals.real.pp_fn = pp_eval.format(lbl='real_label')

-  # config.fewshot = get_fewshot_lsr()
-  # config.fewshot.log_steps = 1000
+  # config.evals.fewshot = get_fewshot_lsr(runlocal=runlocal)
+  # config.evals.fewshot.log_steps = 1000

   return config

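To make the refactor easier to scan, here is a minimal standalone sketch of the `config.input` grouping this commit introduces, with values copied from the diff above. It only builds the ConfigDict (no training); the `print` at the end just shows that attribute access still works after the nesting.

import ml_collections as mlc

# Sketch of the new input-pipeline layout: settings that used to sit at the top
# level (dataset, split, batch size, caching, shuffling, pp string) now live
# under `config.input`, with dataset name and split grouped in `config.input.data`.
config = mlc.ConfigDict()
config.input = dict()
config.input.data = dict(
    name='imagenet2012',   # was `config.dataset`
    split='train[:99%]',   # was `config.train_split`
)
config.input.batch_size = 4096              # was `config.batch_size`
config.input.cache_raw = True               # was `config.cache_raw`
config.input.shuffle_buffer_size = 250_000  # was `config.shuffle_buffer_size`
config.input.pp = 'decode_jpeg_and_inception_crop(224)|flip_lr'  # was `config.pp_train`

print(config.input.data.name)  # nested dicts become ConfigDicts, so this prints 'imagenet2012'
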
big_vision/configs/bit_i21k.py

Lines changed: 25 additions & 22 deletions

@@ -28,23 +28,25 @@ def get_config():
   """Config for training on imagenet-21k."""
   config = mlc.ConfigDict()

-  config.dataset = 'imagenet21k'
-  config.train_split = 'full[51200:]'
+  config.seed = 0
+  config.total_epochs = 90
   config.num_classes = 21843
   config.init_head_bias = -10.0
   config.loss = 'sigmoid_xent'

-  config.trial = 0
-  config.batch_size = 4096
-  config.total_epochs = 90
+  config.input = dict()
+  config.input.data = dict(
+      name='imagenet21k',
+      split='full[51200:]',
+  )
+  config.input.batch_size = 4096
+  config.input.shuffle_buffer_size = 250_000  # Per host, so small-ish is ok.

   pp_common = '|value_range(-1, 1)|onehot({onehot_args})|keep("image", "labels")'
   pp_common_i21k = pp_common.format(onehot_args=f'{config.num_classes}')
   pp_common_i1k = pp_common.format(onehot_args='1000, key="label", key_result="labels"')
-  config.pp_train = 'decode_jpeg_and_inception_crop(224)|flip_lr' + pp_common_i21k
-  pp_eval = 'decode|resize_small(256)|central_crop(224)' + pp_common_i21k
-  pp_eval_i1k = 'decode|resize_small(256)|central_crop(224)' + pp_common_i1k
-  config.shuffle_buffer_size = 250_000  # Per host, so small-ish is ok.
+  config.input.pp = 'decode_jpeg_and_inception_crop(224)|flip_lr' + pp_common_i21k
+  pp_eval = 'decode|resize_small(256)|central_crop(224)'

   config.log_training_steps = 50
   config.ckpt_steps = 1000
@@ -58,22 +60,23 @@ def get_config():
   config.grad_clip_norm = 1.0

   # linear scaling rule. Don't forget to sweep if sweeping batch_size.
-  config.lr = (0.03 / 256) * config.batch_size
-  config.wd = (3e-5 / 256) * config.batch_size
+  config.lr = (0.03 / 256) * config.input.batch_size
+  config.wd = (3e-5 / 256) * config.input.batch_size
   config.schedule = dict(decay_type='cosine', warmup_steps=5000)

-  # Eval section
-  eval_common = dict(
-      type='classification',
-      dataset=config.dataset,
-      pp_fn=pp_eval,
-      loss_name=config.loss,
-      log_steps=1000,  # Very fast O(seconds) so it's fine to run it often.
-  )
+  # Evaluations on i21k itself.
+  def eval_i21k(split):
+    return dict(
+        type='classification',
+        data={**config.input.data, 'split': split},
+        pp_fn=pp_eval + pp_common_i21k,
+        loss_name=config.loss,
+        log_steps=1000,  # Very fast O(seconds) so it's fine to run it often.
+    )
   config.evals = {}
-  config.evals.test = {**eval_common, 'split': 'full[:25_600]'}
-  config.evals.val = {**eval_common, 'split': 'full[25_600:51_200]'}
-  config.evals.train = {**eval_common, 'split': 'full[51_200:76_800]'}
+  config.evals.test = eval_i21k('full[:25_600]')
+  config.evals.val = eval_i21k('full[25_600:51_200]')
+  config.evals.train = eval_i21k('full[51_200:76_800]')

   # Few-shot evaluators
   config.evals.fewshot = get_fewshot_lsr()

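The `eval_i21k` helper above derives each evaluator's data spec from the training spec by copying `config.input.data` and overriding only the split. A small plain-dict illustration of that merge pattern, with values taken from the diff:

# `{**d, 'split': split}` makes a shallow copy of `d` with one key replaced,
# leaving the original training spec untouched.
train_data = dict(name='imagenet21k', split='full[51200:]')

def eval_data(split):
  return {**train_data, 'split': split}

assert eval_data('full[:25_600]') == dict(name='imagenet21k', split='full[:25_600]')
assert train_data['split'] == 'full[51200:]'  # training spec is not mutated
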
big_vision/configs/common_fewshot.py

Lines changed: 1 addition & 1 deletion

@@ -39,7 +39,7 @@ def get_fewshot_lsr(target_resolution=224, resize_resolution=256,
   }
   config.pp_train = f'decode|resize({resize_resolution})|central_crop({target_resolution})|value_range(-1,1)|keep("image", "label")'
   config.pp_eval = f'decode|resize({resize_resolution})|central_crop({target_resolution})|value_range(-1,1)|keep("image", "label")'
-  config.shots = [1, 5, 10, 25]
+  config.shots = (1, 5, 10, 25)
   config.l2_reg = 2.0 ** 10
   config.num_seeds = 3
   config.display_first = [('imagenet', 10)] if not runlocal else [('pets', 10)]

big_vision/configs/load_and_eval.py

Lines changed: 1 addition & 0 deletions

@@ -36,6 +36,7 @@

 import big_vision.configs.common as bvcc
 from big_vision.configs.common_fewshot import get_fewshot_lsr
+from big_vision.configs.proj.image_text import lit_eval
 import ml_collections as mlc

big_vision/configs/mlp_mixer_i1k.py

Lines changed: 31 additions & 30 deletions

@@ -31,37 +31,39 @@ def get_config(mode=None):
   """Config for training Mixer on i1k."""
   config = mlc.ConfigDict()

-  config.dataset = 'imagenet2012'
-  config.train_split = 'train[:99%]'
-  config.cache_raw = True  # Needs up to 120GB of RAM!
+  config.seed = 0
+  config.total_epochs = 300
   config.num_classes = 1000
-  config.init_head_bias = -6.9
   config.loss = 'sigmoid_xent'
+  config.init_head_bias = -6.9
+
+  config.input = dict()
+  config.input.data = dict(
+      name='imagenet2012',
+      split='train[:99%]',
+  )
+  config.input.batch_size = 4096
+  config.input.cache_raw = True  # Needs up to 120GB of RAM!
+  config.input.shuffle_buffer_size = 250_000

-  config.pp_train = (
+  config.input.pp = (
       'decode_jpeg_and_inception_crop(224)'
       '|flip_lr'
       '|randaug(2,15)'
       '|value_range(-1, 1)'
       '|onehot(1000, key="label", key_result="labels")'
       '|keep("image", "labels")'
   )
-  ppv = (
+  pp_eval = (
       'decode'
       '|resize_small(256)|central_crop(224)'
       '|value_range(-1, 1)'
       '|onehot(1000, key="{lbl}", key_result="labels")'
       '|keep("image", "labels")'
   )

-  config.batch_size = 4096
-  config.total_epochs = 300
-
-  config.shuffle_buffer_size = 250_000  # Per host, so small-ish is ok.
-
   config.log_training_steps = 50
   config.ckpt_steps = 1000
-  config.ckpt_timeout = 1

   config.prefetch_to_device = 2

@@ -86,30 +88,29 @@ def get_config(mode=None):
   )

   # Eval section
-  eval_common = dict(
-      type='classification',
-      dataset='imagenet2012',
-      pp_fn=ppv.format(lbl='label'),
-      loss_name=config.loss,
-      log_steps=2500,  # Very fast O(seconds) so it's fine to run it often.
-  )
+  def get_eval(split, dataset='imagenet2012'):
+    return dict(
+        type='classification',
+        data=dict(name=dataset, split=split),
+        pp_fn=pp_eval.format(lbl='label'),
+        loss_name=config.loss,
+        log_steps=2500,  # Very fast O(seconds) so it's fine to run it often.
+        cache_final=mode != 'gpu8',
+    )
   config.evals = {}
-  config.evals.train = {**eval_common, 'split': 'train[:2%]'}
-  config.evals.minival = {**eval_common, 'split': 'train[99%:]'}
-  config.evals.val = {**eval_common, 'split': 'validation'}
-  config.evals.v2 = {**eval_common, 'dataset': 'imagenet_v2', 'split': 'test'}
-
-  config.evals.real = dict(**eval_common)
-  config.evals.real.dataset = 'imagenet2012_real'
-  config.evals.real.split = 'validation'
-  config.evals.real.pp_fn = ppv.format(lbl='real_label')
+  config.evals.train = get_eval('train[:2%]')
+  config.evals.minival = get_eval('train[99%:]')
+  config.evals.val = get_eval('validation')
+  config.evals.v2 = get_eval('test', dataset='imagenet_v2')
+  config.evals.real = get_eval('validation', dataset='imagenet2012_real')
+  config.evals.real.pp_fn = pp_eval.format(lbl='real_label')

   config.fewshot = get_fewshot_lsr()

   if mode == 'gpu8':
     config.total_epochs = 60
-    config.batch_size = 512
-    config.cache_raw = False
+    config.input.batch_size = 512
+    config.input.cache_raw = False
   if mode == 'regression_test':
     config.total_epochs = 60

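A short usage sketch of the `mode` switch in this Mixer config. Calling `get_config` directly like this is only an illustration (it assumes the big_vision repo is importable; in normal use the config is handed to the trainer via flags), but the overridden values are the ones visible in the diff above:

# Sketch: `mode='gpu8'` scales the config down for a single 8-GPU host.
from big_vision.configs import mlp_mixer_i1k

config = mlp_mixer_i1k.get_config(mode='gpu8')
assert config.total_epochs == 60         # shortened schedule
assert config.input.batch_size == 512    # down from 4096
assert config.input.cache_raw is False   # raw caching disabled
# Evaluators built by get_eval() also get cache_final=False in this mode.
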
big_vision/configs/proj/distill/bigsweep_flowers_pet.py

Lines changed: 38 additions & 33 deletions

@@ -50,18 +50,21 @@ def get_config(arg=None):
   arg = bvcc.parse_arg(arg, runlocal=False, data='flowers', variant='medium', crop='inception_crop(128)')
   config = mlc.ConfigDict()

-  config.dataset = dict(flowers='oxford_flowers102', pet='oxford_iiit_pet')[arg.data]
-  config.cache_raw = True
+  config.input = {}
+  config.input.data = dict(
+      name=dict(flowers='oxford_flowers102', pet='oxford_iiit_pet')[arg.data],
+      split=dict(flowers='train', pet='train[:90%]')[arg.data],
+  )
+  config.input.batch_size = 512
+  config.input.cache_raw = True
+  config.input.shuffle_buffer_size = 50_000
   config.prefetch_to_device = 4
-  config.train_split = dict(flowers='train', pet='train[:90%]')[arg.data]
-  config.num_classes = NCLS[arg.data]

-  config.batch_size = 512
-  config.num_epochs = {
+  config.num_classes = NCLS[arg.data]
+  config.total_epochs = {
       'flowers': {'fast': 10_000, 'medium': 100_000, 'long': 1_000_000},
       'pet': {'fast': 1000, 'medium': 3000, 'long': 30_000},
   }[arg.data][arg.variant]
-  config.shuffle_buffer_size = 50_000

   config.log_training_steps = 100
   config.ckpt_steps = 2500
@@ -81,7 +84,7 @@ def get_config(arg=None):
       f'|onehot({config.num_classes}, key="label", key_result="labels")'
       '|keep("image", "labels")'
   )
-  config.pp_train = f'decode|{arg.crop}|flip_lr' + pp_common
+  config.input.pp = f'decode|{arg.crop}|flip_lr' + pp_common
   ppv = 'decode|resize_small(160)|central_crop(128)' + pp_common

   config.mixup = dict(p=1.0, n=2)
@@ -118,18 +121,19 @@ def get_config(arg=None):
   val_split = 'train[90%:]' if not arg.runlocal else 'train[:16]'
   test_split = 'test' if not arg.runlocal else 'test[:16]'

-  base = dict(
-      type='classification',
-      pred='student_fwd',
-      dataset=config.dataset,
-      pp_fn=ppv,
-      loss_name='softmax_xent',
-      log_steps=500,
-  )
+  def get_eval(split):
+    return dict(
+        type='classification',
+        pred='student_fwd',
+        data=dict(name=config.input.data.name, split=split),
+        pp_fn=ppv,
+        loss_name='softmax_xent',
+        log_steps=500,
+    )
   config.evals = {}
-  config.evals.student_train = {**base, 'split': minitrain_split}
-  config.evals.student_val = {**base, 'split': val_split}
-  config.evals.student_test = {**base, 'split': test_split}
+  config.evals.student_train = get_eval(minitrain_split)
+  config.evals.student_val = get_eval(val_split)
+  config.evals.student_test = get_eval(test_split)

   # Teacher is fixed, so rare evals.
   teacher = dict(log_steps=100_000, pred='prof_m_fwd')
@@ -138,22 +142,23 @@ def get_config(arg=None):
   config.evals.teacher_test = {**config.evals.student_test, **teacher}

   # Could in principle also look at agreement on other datasets!
-  dist = dict(
-      type='proj.distill.distance',
-      pred='student_prof_m_fwd',
-      dataset=config.dataset,
-      pp_fn=ppv + '|keep("image")',
-      log_steps=1000,
-      distances=({'kind': 'kl'}, {'kind': 'euclidean'},
-                 {'kind': 'agree', 'k': 1}, {'kind': 'agree', 'k': 5}),
-  )
-  config.evals.dist_train = {**dist, 'split': minitrain_split}
-  config.evals.dist_val = {**dist, 'split': val_split}
-  config.evals.dist_test = {**dist, 'split': test_split}
+  def get_dist(split):
+    return dict(
+        type='proj.distill.distance',
+        pred='student_prof_m_fwd',
+        data=dict(name=config.input.data.name, split=split),
+        pp_fn=ppv + '|keep("image")',
+        log_steps=1000,
+        distances=({'kind': 'kl'}, {'kind': 'euclidean'},
+                   {'kind': 'agree', 'k': 1}, {'kind': 'agree', 'k': 5}),
+    )
+  config.evals.dist_train = get_dist(minitrain_split)
+  config.evals.dist_val = get_dist(val_split)
+  config.evals.dist_test = get_dist(test_split)

   # Make a few things much smaller for quick local debugging testruns.
   if arg.runlocal:
-    config.shuffle_buffer_size = 10
-    config.batch_size = 8
+    config.input.shuffle_buffer_size = 10
+    config.input.batch_size = 8

   return config

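The same refactoring pattern recurs in every config touched by this commit: a shared eval template dict spread with `{**base, 'split': ...}` becomes a small factory function, and the flat `dataset`/`split` keys become a nested `data=dict(name=..., split=...)`. A schematic before/after, with placeholder dataset and preprocessing strings:

# Before: one template dict, specialised per evaluator by spreading and overriding.
base = dict(type='classification', dataset='oxford_iiit_pet', pp_fn='decode|...', log_steps=500)
evals_old = {'student_val': {**base, 'split': 'train[90%:]'}}

# After: a factory builds a fresh spec per evaluator, keeping name and split
# together under a nested `data` dict instead of flat top-level keys.
def get_eval(split, name='oxford_iiit_pet'):
  return dict(
      type='classification',
      data=dict(name=name, split=split),
      pp_fn='decode|...',
      log_steps=500,
  )

evals_new = {'student_val': get_eval('train[90%:]')}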