# pip install timm wandb tqdm datasets
#
# Fine-tune a pre-trained ViT-Base on the RESISC45 dataset with BF16 AMP, using
# the default Adam optimizer from PyTorch core:
#
#     python benchmarks_adam_8bit.py \
#         --model "timm/vit_base_patch16_224.augreg_in21k" \
#         --amp bf16 \
#         --optim Adam
#
# To use the bnb 8-bit optimizer, set --optim Adam8bitBnb. To use the 8-bit
# optimizer implemented in torchao, set --optim Adam8bitAo.
# To profile and export a Chrome trace, set --profile.
# To enable the cosine learning rate scheduler, set --cosine_lr_scheduler.
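#
# For example, combining the flags defined below to benchmark the torchao
# 8-bit Adam under torch.compile:
#
#     python benchmarks_adam_8bit.py \
#         --model "timm/vit_base_patch16_224.augreg_in21k" \
#         --amp bf16 \
#         --optim Adam8bitAo \
#         --compile
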
import argparse
import math
from contextlib import nullcontext
from pathlib import Path

import bitsandbytes as bnb
import datasets
import timm
import torch
import torch.nn.functional as F
from torch.profiler import ProfilerActivity, profile
from torch.utils.data import DataLoader
from torchvision.transforms import v2
from tqdm import tqdm

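# NOTE: Adam8bit is imported from torchao's prototype namespace below;
# prototype APIs are experimental and may move in later torchao releases.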
from torchao.prototype.optim_8bit import Adam8bit


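# Cosine learning-rate schedule with linear warmup: the LR ramps linearly from
# 0 to `lr` over the first `warmup` fraction of steps, then decays to 0 along
# a half-cosine.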
class CosineSchedule:
    def __init__(self, lr: float, total_steps: int, warmup: float = 0.05) -> None:
        self.lr = lr
        self.final_lr = 0
        self.total_steps = total_steps
        self.warmup_steps = round(total_steps * warmup)

    def get_lr(self, step: int) -> float:
        if step < self.warmup_steps:
            return self.lr * step / self.warmup_steps
        if step < self.total_steps:
            progress = (step - self.warmup_steps) / (self.total_steps - self.warmup_steps)
            return self.final_lr + 0.5 * (self.lr - self.final_lr) * (1 + math.cos(progress * math.pi))
        return self.final_lr


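# Thin wandb wrapper: a real run is created only when --project is set and
# --profile is off; otherwise log() is a silent no-op, so call sites never
# need to check.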
class WandbLogger:
    def __init__(self, args):
        if args.project is not None and not args.profile:
            import wandb

            Path("wandb_logs").mkdir(exist_ok=True)
            self.run = wandb.init(project=args.project, name=args.run_name, config=args, dir="wandb_logs")
        else:
            self.run = None

    def log(self, *args, **kwargs):
        if self.run is not None:
            self.run.log(*args, **kwargs)


def get_parser():
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", required=True)

    parser.add_argument("--amp", default="none")
    parser.add_argument("--channels_last", action="store_true")
    parser.add_argument("--compile", action="store_true")

    parser.add_argument("--n_epochs", type=int, default=10)
    parser.add_argument("--batch_size", type=int, default=64)
    parser.add_argument("--n_workers", type=int, default=4)

    parser.add_argument("--optim", default="Adam")
    parser.add_argument("--lr", type=float, default=1e-4)
    parser.add_argument("--weight_decay", type=float, default=0)
    parser.add_argument("--cosine_lr_scheduler", action="store_true")

    parser.add_argument("--project")
    parser.add_argument("--run_name", default="debug")
    parser.add_argument("--profile", action="store_true")
    return parser


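# ImageNet-style pipeline: RandomResizedCrop + horizontal flip for training,
# Resize + CenterCrop for evaluation, normalized with ImageNet statistics.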
def get_dloader(args, training: bool):
    transforms = [v2.ToImage()]

    if training:
        transforms.extend([v2.RandomResizedCrop(224), v2.RandomHorizontalFlip()])
    else:
        transforms.extend([v2.Resize(256), v2.CenterCrop(224)])

    transforms.append(v2.ToDtype(torch.float32, scale=True))
    transforms.append(v2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]))
    transforms = v2.Compose(transforms)

    # use the dataset from the HF Hub so the download is fast
    ds = datasets.load_dataset("timm/resisc45", split="train" if training else "validation")
    ds = ds.select_columns(["image", "label"])
    ds.set_transform(lambda x: dict(image=transforms(x["image"]), label=x["label"]))

    return DataLoader(
        ds,
        batch_size=args.batch_size,
        shuffle=training,
        num_workers=args.n_workers,
        pin_memory=training,
        drop_last=training,
    )


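# Map the --amp flag to a torch.autocast context; autocast is disabled
# entirely when amp == "none".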
def get_amp_ctx(amp):
    dtype = dict(bf16=torch.bfloat16, fp16=torch.float16, none=None)[amp]
    return torch.autocast("cuda", dtype=dtype, enabled=amp != "none")


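# Top-1 accuracy on the validation split. Predictions are moved to the CPU
# each batch so evaluation memory stays flat on the GPU.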
@torch.no_grad()
def evaluate_model(model, args):
    model.eval()
    val_dloader = get_dloader(args, False)

    all_labels = []
    all_preds = []

    for batch in tqdm(val_dloader, dynamic_ncols=True, desc="Evaluating"):
        all_labels.append(batch["label"].clone())
        if args.channels_last:
            batch["image"] = batch["image"].to(memory_format=torch.channels_last)

        with get_amp_ctx(args.amp):
            all_preds.append(model(batch["image"].cuda()).argmax(1).cpu())

    all_labels = torch.cat(all_labels, dim=0)
    all_preds = torch.cat(all_preds, dim=0)

    acc = (all_labels == all_preds).float().mean()
    return acc


if __name__ == "__main__":
    args = get_parser().parse_args()

    if args.profile:
        args.n_epochs = 1

    for k, v in vars(args).items():
        print(f"{k}: {v}")

    # wandb is only enabled when args.project is set and args.profile is False
    logger = WandbLogger(args)
    dloader = get_dloader(args, True)
    print(f"Train dataset: {len(dloader.dataset):,} images")

    model = timm.create_model(args.model, pretrained=True, num_classes=45).cuda()
    if args.channels_last:
        model.to(memory_format=torch.channels_last)
    if args.compile:
        model.compile(fullgraph=True)
    print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")

    OPTIM_MAP = dict(
        Adam=torch.optim.Adam,
        Adam8bitBnb=bnb.optim.Adam8bit,
        Adam8bitAo=Adam8bit,
    )
    optim = OPTIM_MAP[args.optim](model.parameters(), args.lr, weight_decay=args.weight_decay)
    lr_schedule = CosineSchedule(args.lr, len(dloader) * args.n_epochs)

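    # Loss scaling is only needed for fp16: bf16 shares fp32's exponent range,
    # so for bf16 and full-precision runs the scaler is a pass-through.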
    grad_scaler = torch.amp.GradScaler("cuda", enabled=args.amp == "fp16")

    step = 0
    for epoch_idx in range(args.n_epochs):
        model.train()
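        # Wrap the epoch in torch.profiler when --profile is set; nullcontext()
        # keeps the loop body unchanged for normal training runs.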
        prof = profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) if args.profile else nullcontext()

        with prof:
            for batch in tqdm(dloader, dynamic_ncols=True, desc=f"Epoch {epoch_idx + 1}/{args.n_epochs}"):
                if args.channels_last:
                    batch["image"] = batch["image"].to(memory_format=torch.channels_last)

                with get_amp_ctx(args.amp):
                    loss = F.cross_entropy(model(batch["image"].cuda()), batch["label"].cuda())
                grad_scaler.scale(loss).backward()

                if args.cosine_lr_scheduler:
                    lr = lr_schedule.get_lr(step)
                    for param_group in optim.param_groups:
                        param_group["lr"] = lr

                if step % 100 == 0:
                    logger.log(dict(loss=loss.item(), lr=optim.param_groups[0]["lr"]), step=step)

                grad_scaler.step(optim)
                grad_scaler.update()
                optim.zero_grad()

                step += 1

                if args.profile and step == 20:
                    break

        if args.profile:
            prof.export_chrome_trace("trace.json")
        else:
            val_acc = evaluate_model(model, args)
            print(f"Epoch {epoch_idx + 1}/{args.n_epochs}: val_acc={val_acc.item() * 100:.2f}")
            logger.log(dict(val_acc=val_acc.item()), step=step)

    print(f"Max memory allocated: {torch.cuda.max_memory_allocated() / (1 << 30):.2f} GB")