-
Notifications
You must be signed in to change notification settings - Fork 213
implement gsam in jax #8
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 2 commits
c6a0dda
67ff195
8cde017
dcd0859
e54eb6a
28b5c4a
f14f615
b0375e1
5736be2
24bec3f
a9a696c
8dec7e0
8dd3364
f420943
3c6e55e
a19e92f
e06c79b
cba60f8
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,142 @@ | ||
| # Copyright 2022 Big Vision Authors. | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
|
|
||
| # pylint: disable=line-too-long | ||
| r"""Pre-training ViT on ILSVRC-2012 as in https://arxiv.org/abs/2203.08065 | ||
|
|
||
| This configuration makes use of the "arg" to get_config to select which model | ||
| to run, so a few examples are given below: | ||
|
|
||
| Run training of a B/16 model: | ||
|
|
||
| big_vision.train \ | ||
| --config big_vision/configs/vit_i1k.py:variant=B/16 \ | ||
| --workdir gs://[your_bucket]/big_vision/`date '+%m-%d_%H%M'` | ||
|
|
||
| Run training of a B/32 model without aug-strength and 300ep: | ||
|
|
||
| big_vision.train \ | ||
| --config big_vision/configs/vit_i1k.py:variant=B/32,aug=none \ | ||
| --workdir gs://[your_bucket]/big_vision/`date '+%m-%d_%H%M'` \ | ||
| --config.num_epochs 300 | ||
| """ | ||
|
|
||
| import big_vision.configs.common as bvcc | ||
| from big_vision.configs.common_fewshot import get_fewshot_lsr | ||
| import ml_collections as mlc | ||
|
|
||
def get_config(arg=None):
  """Returns the training config.

  Args:
    arg: comma-separated "key=value" string parsed by `bvcc.parse_arg`.
      Supported keys: `variant` (ViT variant, e.g. 'B/16'), `runlocal`
      (shrink everything for a quick local debug run), `aug`.

  Returns:
    An `ml_collections.ConfigDict` holding the full training configuration.
  """
  arg = bvcc.parse_arg(arg, variant='B/32', runlocal=False, aug='')
  # NOTE(review): `arg.aug` is parsed but never consumed below, although the
  # module docstring advertises an `aug=none` option — confirm whether the
  # augmentation-selection logic is still missing.
  config = mlc.ConfigDict()

  # Input pipeline.
  config.dataset = 'imagenet2012'
  config.train_split = 'train[:99%]'
  config.cache_raw = not arg.runlocal  # Needs up to 120GB of RAM!
  config.shuffle_buffer_size = 250_000  # Per host, so small-ish is ok.
  config.num_classes = 1000
  config.loss = 'sigmoid_xent'
  config.batch_size = 4096
  config.num_epochs = 300

  pp_common = (
      '|value_range(-1, 1)'
      '|onehot(1000, key="{lbl}", key_result="labels")'
      '|keep("image", "labels")'
  )
  config.pp_train = (
      'decode_jpeg_and_inception_crop(224)|flip_lr|' +
      pp_common.format(lbl='label')
  )
  # Eval preprocessing; the '{lbl}' placeholder is filled per-evaluator below.
  pp = 'decode|resize_small(256)|central_crop(224)' + pp_common

  # Aggressive pre-fetching because our models here are small, so we not only
  # can afford it, but we also need it for the smallest models to not be
  # bottle-necked by the input pipeline. Play around with it for -L models tho.
  config.prefetch_to_host = 8
  config.prefetch_to_device = 4

  config.log_training_steps = 50
  config.checkpoint_steps = 1000

  # Model section.
  config.model_name = 'vit'
  config.model = dict(
      variant=arg.variant,
      rep_size=False,
      pool_type='gap',
  )

  # Optimizer section.
  config.grad_clip_norm = 1.0
  config.optax_name = 'scale_by_adam'
  config.optax = dict(mu_dtype='bfloat16')
  # The modified AdaFactor we introduced in https://arxiv.org/abs/2106.04560
  # almost always behaves exactly like adam, but at a fraction of the memory
  # cost (specifically, adam_bf16 = +1.5M, adafactor = +0.5M), hence it is a
  # good idea to try it when you are memory-bound!
  # config.optax_name = 'big_vision.scale_by_adafactor'
  # A good flag to play with when hitting instabilities, is the following:
  # config.optax = dict(beta2_cap=0.95)

  config.lr = 0.003
  config.wd = 0.001  # default is 0.0001; paper used 0.3, effective wd=0.3*lr
  config.schedule = dict(
      warmup_steps=10_000,
      decay_type='linear',
      linear_end=0.00003,  # Final learning rate at the end of the linear decay.
  )

  # GSAM settings.
  # Note: when rho_max=rho_min and alpha=0, GSAM reduces to SAM.
  config.gsam = dict(
      rho_max=0.6,
      rho_min=0.1,
      alpha=0.6,
      adaptive_perturbation=False,
      minimize_fp=True,
  )

  # Eval section.
  eval_common = dict(
      type='classification',
      dataset='imagenet2012',
      pp_fn=pp.format(lbl='label'),
      loss_name=config.loss,
      log_steps=2500,  # Very fast O(seconds) so it's fine to run it often.
  )
  config.evals = {}
  config.evals.train = {**eval_common, 'split': 'train[:2%]'}
  config.evals.minival = {**eval_common, 'split': 'train[99%:]'}
  config.evals.val = {**eval_common, 'split': 'validation'}
  config.evals.v2 = {**eval_common, 'dataset': 'imagenet_v2', 'split': 'test'}

  config.evals.real = {**eval_common}
  config.evals.real.dataset = 'imagenet2012_real'
  config.evals.real.split = 'validation'
  config.evals.real.pp_fn = pp.format(lbl='real_label')

  config.fewshot = get_fewshot_lsr(runlocal=arg.runlocal)
  config.fewshot.log_steps = 10_000

  # Make a few things much smaller for quick local debugging testruns.
  if arg.runlocal:
    config.shuffle_buffer_size = 10
    config.batch_size = 8
    # Bug fix: the eval splits live under `config.evals.*`, not `config.*`;
    # the previous `config.minival.split` etc. referenced attributes that were
    # never set and would raise an AttributeError at config-build time.
    config.evals.minival.split = 'train[:16]'
    config.evals.val.split = 'validation[:16]'
    config.evals.real.split = 'validation[:16]'
    config.evals.v2.split = 'test[:16]'

  return config
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,94 @@ | ||
| import jax | ||
| import jax.numpy as jnp | ||
|
|
||
def dual_vector(y):
  """Returns the solution of max_x y^T x s.t. ||x||_2 <= 1.

  The argmax is y / ||y||_2, i.e. the gradient direction normalized by the
  global L2 norm taken over all leaves of the pytree jointly.

  Args:
    y: A pytree of numpy ndarray, vector y in the equation above.

  Returns:
    A tuple (normalized_gradient, gradient_norm) where `normalized_gradient`
    is a pytree of the same structure as `y` with every leaf divided by the
    global norm, and `gradient_norm` is that scalar global L2 norm.
    NOTE(review): if all leaves of `y` are zero, this divides by zero —
    callers appear to guard with an eps at the use site; confirm.
  """
  gradient_norm = jnp.sqrt(sum(
      jnp.sum(jnp.square(e)) for e in jax.tree_util.tree_leaves(y)))
  # `jax.tree_map` is deprecated/removed in recent JAX; use the
  # `jax.tree_util` namespace consistently with `tree_leaves` above.
  normalized_gradient = jax.tree_util.tree_map(
      lambda x: x / gradient_norm, y)
  return normalized_gradient, gradient_norm
|
|
||
def gsam_gradient(loss_fn, base_opt, inputs, targets,
                  rho_max, rho_min, alpha, lr, lr_max, lr_min, eps=1e-12,
                  adaptive_perturbation=False, minimize_fp=True):
  """Get the GSAM gradient (https://openreview.net/pdf?id=edONMAnhLu-) of the loss function.

  Args:
    loss_fn: the loss function, called as `loss_fn(params, inputs, targets)`.
    base_opt: the base optimizer; its `.target` attribute holds the current
      model parameters (a pytree).
    inputs: the inputs to the loss function.
    targets: the targets to the loss function.
    rho_max: the maximum rho value for perturbation of weights.
    rho_min: the minimum rho value for perturbation of weights.
    alpha: the alpha value for the rho schedule, see Algorithm 1 in the paper.
    lr: current learning rate.
    lr_max: the maximum learning rate.
    lr_min: the minimum learning rate.
    eps: the epsilon value for numerical stability (guards division by the
      clean-gradient norm).
    adaptive_perturbation: if False, same perturbation as SAM — all parameters
      are treated as a single flat vector and perturbed by
      `rho * g / ||g||`; if True, each parameter's perturbation is
      additionally scaled element-wise by `abs(p)`.
    minimize_fp: if True, minimize the perturbed loss f_p minus the surrogate
      gap (original GSAM): the GSAM gradient is
      `g_robust - alpha * (g_clean - proj_{g_robust}(g_clean))`.
      If False, minimize the clean loss f plus the gap: the gradient is
      `g_clean + alpha * (g_robust - proj_{g_clean}(g_robust))`.

  Returns:
    l_clean: the clean (unperturbed) loss value.
    g_gsam: the GSAM gradient pytree. g_gsam is not averaged across workers;
      call `jax.lax.pmean` afterwards to average.

  Note:
    Setting `rho_max=rho_min` and `alpha=0` reduces GSAM to SAM.
  """
  l_clean, g_clean = jax.value_and_grad(loss_fn)(
      base_opt.target, inputs, targets)
  g_clean_normalized, g_clean_length = dual_vector(g_clean)

  # rho follows the learning-rate schedule linearly: rho_min at lr_min,
  # rho_max at lr_max (Algorithm 1 in the paper). Constant-lr schedules
  # (lr_max == lr_min) fall back to rho_max to avoid dividing by zero.
  if lr_max == lr_min:
    sam_rho = rho_max
  else:
    sam_rho = rho_min + (rho_max - rho_min) * (lr - lr_min) / (lr_max - lr_min)

  # Per-worker perturbation of the weights.
  # `jax.tree_multimap` was removed from JAX; `jax.tree_util.tree_map`
  # accepts multiple trees and is the supported replacement.
  if adaptive_perturbation:
    param_sam = jax.tree_util.tree_map(
        lambda a, b: a + jnp.abs(a) * sam_rho * b / (g_clean_length + eps),
        base_opt.target, g_clean)
  else:
    param_sam = jax.tree_util.tree_map(
        lambda a, b: a + sam_rho * b / (g_clean_length + eps),
        base_opt.target, g_clean)

  # Get gradients at the perturbed weights (the perturbed loss value itself
  # is not used; only the clean loss is returned).
  _, g_robust = jax.value_and_grad(loss_fn)(param_sam, inputs, targets)

  # Decompose gradients.
  g_clean_flatten, _ = jax.tree_util.tree_flatten(g_clean)
  g_robust_flatten, _ = jax.tree_util.tree_flatten(g_robust)

  if minimize_fp:
    # Decompose g_clean into components parallel and orthogonal to g_robust.
    g_robust_normalized, _ = dual_vector(g_robust)
    g_robust_normalized_flatten, _ = jax.tree_util.tree_flatten(
        g_robust_normalized)

    g_clean_projection_norm = sum(
        jnp.vdot(p, q) for p, q in
        zip(g_robust_normalized_flatten, g_clean_flatten))
    g_clean_residual = jax.tree_util.tree_map(
        lambda a, b: a - g_clean_projection_norm * b,
        g_clean, g_robust_normalized)

    # Get GSAM gradient: perturbed gradient minus alpha times the residual.
    g_gsam = jax.tree_util.tree_map(
        lambda a, b: a - b * alpha, g_robust, g_clean_residual)
  else:
    # Decompose g_robust into components parallel and orthogonal to g_clean.
    # Reuses g_clean_normalized computed at the top (the original recomputed
    # dual_vector(g_clean) here redundantly).
    g_clean_normalized_flatten, _ = jax.tree_util.tree_flatten(
        g_clean_normalized)

    g_robust_projection_norm = sum(
        jnp.vdot(p, q) for p, q in
        zip(g_clean_normalized_flatten, g_robust_flatten))
    g_robust_residual = jax.tree_util.tree_map(
        lambda a, b: a - g_robust_projection_norm * b,
        g_robust, g_clean_normalized)

    # Get GSAM gradient: clean gradient plus alpha times the residual.
    g_gsam = jax.tree_util.tree_map(
        lambda a, b: a + b * alpha, g_clean, g_robust_residual)

  # Always return the clean loss (rather than the perturbed loss).
  return l_clean, g_gsam
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
All of these example commands need to be updated to this config file.