Skip to content
Merged
Show file tree
Hide file tree
Changes from 17 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
136 changes: 136 additions & 0 deletions big_vision/configs/proj/gsam/vit_i1k_gsam_no_aug.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@
# Copyright 2022 Big Vision Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# pylint: disable=line-too-long
r"""Pre-training ViT on ILSVRC-2012 with GSAM in https://arxiv.org/abs/2203.08065

Run training of a B/32 model:

big_vision.trainers.proj.gsam.train \
--config big_vision/configs/proj/gsam/vit_i1k_gsam_no_aug.py \
--workdir gs://[your_bucket]/big_vision/`date '+%m-%d_%H%M'`

"""

import big_vision.configs.common as bvcc
from big_vision.configs.common_fewshot import get_fewshot_lsr
import ml_collections as mlc

def get_config(arg=None):
  """Returns the config for pre-training ViT on i1k with GSAM, no augmentation.

  Args:
    arg: optional config-string, parsed by `bvcc.parse_arg`. Supports
      `variant` (ViT variant, default 'B/32') and `runlocal` (shrink
      everything for a quick local debug run).
  """
  arg = bvcc.parse_arg(arg, variant='B/32', runlocal=False)
  config = mlc.ConfigDict()

  # Input section.
  config.dataset = 'imagenet2012'
  config.train_split = 'train[:99%]'
  config.cache_raw = not arg.runlocal  # Needs up to 120GB of RAM!
  config.shuffle_buffer_size = 250_000  # Per host, so small-ish is ok.
  config.num_classes = 1000
  config.loss = 'sigmoid_xent'
  config.batch_size = 4096
  config.num_epochs = 300

  # Shared pp suffix; "{lbl}" is filled in per-use because the RealLabels
  # eval reads a different label key ('real_label') than the others.
  pp_common = (
      '|value_range(-1, 1)'
      '|onehot(1000, key="{lbl}", key_result="labels")'
      '|keep("image", "labels")'
  )
  config.pp_train = (
      'decode_jpeg_and_inception_crop(224)|flip_lr|' +
      pp_common.format(lbl='label')
  )
  pp = 'decode|resize_small(256)|central_crop(224)' + pp_common

  # Aggressive pre-fetching because our models here are small, so we not only
  # can afford it, but we also need it for the smallest models to not be
  # bottle-necked by the input pipeline. Play around with it for -L models tho.
  config.prefetch_to_host = 8
  config.prefetch_to_device = 4

  config.log_training_steps = 50
  config.checkpoint_steps = 1000

  # Model section.
  config.model_name = 'vit'
  config.model = dict(
      variant=arg.variant,
      rep_size=False,
      pool_type='gap',
  )
  config.init_head_bias = -10.0

  # Optimizer section.
  config.grad_clip_norm = 1.0
  config.optax_name = 'scale_by_adam'
  config.optax = dict(mu_dtype='float32')
  # The modified AdaFactor we introduced in https://arxiv.org/abs/2106.04560
  # almost always behaves exactly like adam, but at a fraction of the memory
  # cost (specifically, adam_bf16 = +1.5M, adafactor = +0.5M), hence it is a
  # good idea to try it when you are memory-bound!
  # config.optax_name = 'big_vision.scale_by_adafactor'
  # A good flag to play with when hitting instabilities, is the following:
  # config.optax = dict(beta2_cap=0.95)

  config.lr = 0.003
  config.wd = 0.001  # default is 0.0001; paper used 0.3, effective wd=0.3*lr
  config.schedule = dict(
      warmup_steps=10_000,
      decay_type='linear',
      linear_end=0.01,
  )

  # GSAM settings.
  # Note: when rho_max=rho_min and alpha=0, GSAM reduces to SAM.
  config.gsam = dict(
      rho_max=0.6,
      rho_min=0.1,
      alpha=0.6,
      adaptive_perturbation=False,  # Trainer default; kept explicit here.
      minimize_fp=True,  # Trainer default; kept explicit here.
      # get_ref ties these to config.lr / config.schedule so that later
      # overrides of lr or linear_end propagate into the rho schedule.
      lr_max=config.get_ref('lr'),
      lr_min=config.schedule.get_ref('linear_end') * config.get_ref('lr'),
  )

  # Eval section.
  eval_common = dict(
      type='classification',
      dataset='imagenet2012',
      pp_fn=pp.format(lbl='label'),
      loss_name=config.loss,
      log_steps=2500,  # Very fast O(seconds) so it's fine to run it often.
  )
  config.evals = {}
  config.evals.train = {**eval_common, 'split': 'train[:2%]'}
  config.evals.minival = {**eval_common, 'split': 'train[99%:]'}
  config.evals.val = {**eval_common, 'split': 'validation'}
  config.evals.v2 = {**eval_common, 'dataset': 'imagenet_v2', 'split': 'test'}

  config.evals.real = {**eval_common}
  config.evals.real.dataset = 'imagenet2012_real'
  config.evals.real.split = 'validation'
  config.evals.real.pp_fn = pp.format(lbl='real_label')

  config.fewshot = get_fewshot_lsr(runlocal=arg.runlocal)
  config.fewshot.log_steps = 10_000

  # Make a few things much smaller for quick local debugging testruns.
  if arg.runlocal:
    config.shuffle_buffer_size = 10
    config.batch_size = 8
    # NOTE: the eval configs live under `config.evals`; addressing them as
    # e.g. `config.minival` (as an earlier draft did) raises AttributeError.
    config.evals.minival.split = 'train[:16]'
    config.evals.val.split = 'validation[:16]'
    config.evals.real.split = 'validation[:16]'
    config.evals.v2.split = 'test[:16]'

  return config
102 changes: 102 additions & 0 deletions big_vision/trainers/proj/gsam/gsam.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
'''This file provides jax implementation of GSAM.'''

import jax
import jax.numpy as jnp

def dual_vector(y):
  """Projects a pytree onto the unit L2 ball.

  Solves max_x y^T x s.t. ||x||_2 <= 1, whose solution is y / ||y||_2.

  Args:
    y: a pytree of numpy ndarrays, the vector y in the equation above.

  Returns:
    A pair `(y_normalized, y_norm)`: the pytree scaled to unit global L2
    norm, and the scalar global L2 norm of `y`.
  """
  leaves = jax.tree_util.tree_leaves(y)
  squared_norm = sum(jnp.sum(jnp.square(leaf)) for leaf in leaves)
  norm = jnp.sqrt(squared_norm)
  unit_tree = jax.tree_map(lambda leaf: leaf / norm, y)
  return unit_tree, norm

def gsam_gradient(loss_fn, params, inputs, targets,
                  rho_max, rho_min, alpha, lr, lr_max, lr_min, eps=1e-12,
                  adaptive_perturbation=False, minimize_fp=True):
  """Computes the GSAM gradient (https://openreview.net/pdf?id=edONMAnhLu-).

  Args:
    loss_fn: the loss function, called as `loss_fn(params, inputs, targets)`.
    params: pytree of model weights.
    inputs: the inputs to the loss function.
    targets: the targets to the loss function.
    rho_max: the maximum rho value for perturbation of weights.
    rho_min: the minimum rho value for perturbation of weights.
    alpha: the alpha value for the rho schedule, see Algorithm 1 in the paper.
    lr: current learning rate.
    lr_max: the maximum learning rate of the schedule.
    lr_min: the minimum learning rate of the schedule.
    eps: the epsilon value for numerical stability.
    adaptive_perturbation: if False, same perturbation as SAM: all parameters
      are treated as a single vector and the perturbation norm is the norm of
      the whole vector. If True, for each parameter tensor p the perturbation
      is element-wise multiplied by abs(p). Leave at the default (False) to
      match SAM/GSAM as published; the adaptive variant is experimental
      (NOTE(review): recommendation inferred from the doc — confirm w/ paper).
    minimize_fp: if True, minimize f_p (the perturbed loss) plus the surrogate
      gap h — the original GSAM objective; if False, minimize f (the clean
      loss) plus h. You probably want the default (True) unless you are
      deliberately experimenting with the objective.

  Returns:
    l_clean: the clean (unperturbed) loss value.
    g_gsam: the GSAM gradient. g_gsam is not averaged across workers,
      need to call "jax.lax.pmean" to average.

  Note:
    Setting `rho_max=rho_min` and `alpha=0` reduces GSAM to SAM.
  """
  l_clean, g_clean = jax.value_and_grad(loss_fn)(params, inputs, targets)
  # Computed once up front; the normalized tree is only consumed by the
  # minimize_fp=False branch below, the norm by the perturbation step.
  g_clean_normalized, g_clean_length = dual_vector(g_clean)

  # rho follows the lr schedule, linearly rescaled from [lr_min, lr_max] to
  # [rho_min, rho_max]. Ideally we'd reuse the trainer's schedule function
  # stretched to a different min/max; this rescaling of the current lr
  # reproduces the shape of whatever lr schedule is in use.
  if lr_max == lr_min:
    # Constant-lr schedule: the interpolation below would divide by zero.
    sam_rho = rho_max
  else:
    sam_rho = rho_min + (rho_max - rho_min) * (lr - lr_min) / (lr_max - lr_min)

  # Per-worker perturbation, in the direction of the (normalized) clean
  # gradient, with radius sam_rho.
  if adaptive_perturbation:
    param_sam = jax.tree_map(
        lambda a, b: a + jnp.abs(a) * sam_rho * b / (g_clean_length + eps),
        params, g_clean)
  else:
    param_sam = jax.tree_map(
        lambda a, b: a + sam_rho * b / (g_clean_length + eps),
        params, g_clean)

  # Get gradients at perturbed weights.
  _, g_robust = jax.value_and_grad(loss_fn)(param_sam, inputs, targets)

  if minimize_fp:
    # Decompose g_clean into components parallel and orthogonal to g_robust,
    # then subtract alpha times the orthogonal residual from g_robust.
    g_robust_normalized, _ = dual_vector(g_robust)
    g_clean_projection_norm = sum(
        jnp.vdot(p, q)
        for p, q in zip(jax.tree_util.tree_leaves(g_robust_normalized),
                        jax.tree_util.tree_leaves(g_clean)))
    g_clean_residual = jax.tree_map(
        lambda a, b: a - g_clean_projection_norm * b,
        g_clean, g_robust_normalized)

    # Get GSAM gradient.
    g_gsam = jax.tree_map(lambda a, b: a - b * alpha,
                          g_robust, g_clean_residual)
  else:
    # Decompose g_robust into components parallel and orthogonal to g_clean,
    # then add alpha times the orthogonal residual to g_clean. Reuses the
    # g_clean_normalized computed at the top instead of recomputing it.
    g_robust_projection_norm = sum(
        jnp.vdot(p, q)
        for p, q in zip(jax.tree_util.tree_leaves(g_clean_normalized),
                        jax.tree_util.tree_leaves(g_robust)))
    g_robust_residual = jax.tree_map(
        lambda a, b: a - g_robust_projection_norm * b,
        g_robust, g_clean_normalized)

    # Get GSAM gradient.
    g_gsam = jax.tree_map(lambda a, b: a + b * alpha,
                          g_clean, g_robust_residual)

  # Always return the clean loss (rather than the perturbed loss).
  return l_clean, g_gsam
Loading