@@ -50,18 +50,21 @@ def get_config(arg=None):
   arg = bvcc.parse_arg(arg, runlocal=False, data='flowers', variant='medium', crop='inception_crop(128)')
   config = mlc.ConfigDict()

-  config.dataset = dict(flowers='oxford_flowers102', pet='oxford_iiit_pet')[arg.data]
-  config.cache_raw = True
+  config.input = {}
+  config.input.data = dict(
+      name=dict(flowers='oxford_flowers102', pet='oxford_iiit_pet')[arg.data],
+      split=dict(flowers='train', pet='train[:90%]')[arg.data],
+  )
+  config.input.batch_size = 512
+  config.input.cache_raw = True
+  config.input.shuffle_buffer_size = 50_000
   config.prefetch_to_device = 4
-  config.train_split = dict(flowers='train', pet='train[:90%]')[arg.data]
-  config.num_classes = NCLS[arg.data]

-  config.batch_size = 512
-  config.num_epochs = {
+  config.num_classes = NCLS[arg.data]
+  config.total_epochs = {
       'flowers': {'fast': 10_000, 'medium': 100_000, 'long': 1_000_000},
       'pet': {'fast': 1000, 'medium': 3000, 'long': 30_000},
   }[arg.data][arg.variant]
-  config.shuffle_buffer_size = 50_000

   config.log_training_steps = 100
   config.ckpt_steps = 2500
@@ -81,7 +84,7 @@ def get_config(arg=None):
       f'|onehot({config.num_classes}, key="label", key_result="labels")'
       '|keep("image", "labels")'
   )
-  config.pp_train = f'decode|{arg.crop}|flip_lr' + pp_common
+  config.input.pp = f'decode|{arg.crop}|flip_lr' + pp_common
   ppv = 'decode|resize_small(160)|central_crop(128)' + pp_common

   config.mixup = dict(p=1.0, n=2)
@@ -118,18 +121,19 @@ def get_config(arg=None):
   val_split = 'train[90%:]' if not arg.runlocal else 'train[:16]'
   test_split = 'test' if not arg.runlocal else 'test[:16]'

-  base = dict(
-      type='classification',
-      pred='student_fwd',
-      dataset=config.dataset,
-      pp_fn=ppv,
-      loss_name='softmax_xent',
-      log_steps=500,
-  )
+  def get_eval(split):
+    return dict(
+        type='classification',
+        pred='student_fwd',
+        data=dict(name=config.input.data.name, split=split),
+        pp_fn=ppv,
+        loss_name='softmax_xent',
+        log_steps=500,
+    )
   config.evals = {}
-  config.evals.student_train = {**base, 'split': minitrain_split}
-  config.evals.student_val = {**base, 'split': val_split}
-  config.evals.student_test = {**base, 'split': test_split}
+  config.evals.student_train = get_eval(minitrain_split)
+  config.evals.student_val = get_eval(val_split)
+  config.evals.student_test = get_eval(test_split)

   # Teacher is fixed, so rare evals.
   teacher = dict(log_steps=100_000, pred='prof_m_fwd')
@@ -138,22 +142,23 @@ def get_config(arg=None):
   config.evals.teacher_test = {**config.evals.student_test, **teacher}

   # Could in principle also look at agreement on other datasets!
-  dist = dict(
-      type='proj.distill.distance',
-      pred='student_prof_m_fwd',
-      dataset=config.dataset,
-      pp_fn=ppv + '|keep("image")',
-      log_steps=1000,
-      distances=({'kind': 'kl'}, {'kind': 'euclidean'},
-                 {'kind': 'agree', 'k': 1}, {'kind': 'agree', 'k': 5}),
-  )
-  config.evals.dist_train = {**dist, 'split': minitrain_split}
-  config.evals.dist_val = {**dist, 'split': val_split}
-  config.evals.dist_test = {**dist, 'split': test_split}
+  def get_dist(split):
+    return dict(
+        type='proj.distill.distance',
+        pred='student_prof_m_fwd',
+        data=dict(name=config.input.data.name, split=split),
+        pp_fn=ppv + '|keep("image")',
+        log_steps=1000,
+        distances=({'kind': 'kl'}, {'kind': 'euclidean'},
+                   {'kind': 'agree', 'k': 1}, {'kind': 'agree', 'k': 5}),
+    )
+  config.evals.dist_train = get_dist(minitrain_split)
+  config.evals.dist_val = get_dist(val_split)
+  config.evals.dist_test = get_dist(test_split)

   # Make a few things much smaller for quick local debugging testruns.
   if arg.runlocal:
-    config.shuffle_buffer_size = 10
-    config.batch_size = 8
+    config.input.shuffle_buffer_size = 10
+    config.input.batch_size = 8

   return config
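For reference, a minimal usage sketch (not part of the diff): it assumes get_config from this file is importable and that the arg string is the comma-separated key=value form consumed by bvcc.parse_arg at the top of the function; the expected values in the comments follow from the dicts set above.

    # Hypothetical sketch: instantiate the config with non-default arg values.
    config = get_config('data=pet,variant=fast')
    print(config.input.data.name)   # 'oxford_iiit_pet'
    print(config.input.data.split)  # 'train[:90%]'
    print(config.total_epochs)      # 3000 (the pet/fast entry)
    print(config.input.batch_size)  # 512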