Commit 739952b

8-bit Adam (#463)
1 parent d1e15b4 commit 739952b

8 files changed

Lines changed: 806 additions & 0 deletions

benchmarks/benchmark_adam_8bit.py

Lines changed: 211 additions & 0 deletions
@@ -0,0 +1,211 @@
```python
# pip install timm wandb tqdm datasets
#
# To fine-tune a pre-trained ViT-Base on the resisc45 dataset with BF16 AMP, using the default Adam optimizer from PyTorch core:
#
# python benchmark_adam_8bit.py \
#     --model "timm/vit_base_patch16_224.augreg_in21k" \
#     --amp bf16 \
#     --optim Adam
#
# To use the bnb 8-bit optimizer, set --optim Adam8bitBnb. To use the 8-bit optimizer implemented in torchao, set --optim Adam8bitAo.
# To profile and export a Chrome trace, set --profile.
# To enable the cosine learning rate scheduler, set --cosine_lr_scheduler.

import argparse
import math
from contextlib import nullcontext
from pathlib import Path

import bitsandbytes as bnb
import datasets
import timm
import torch
import torch.nn.functional as F
from torch.profiler import ProfilerActivity, profile
from torch.utils.data import DataLoader
from torchvision.transforms import v2
from tqdm import tqdm

from torchao.prototype.optim_8bit import Adam8bit


class CosineSchedule:
    def __init__(self, lr: float, total_steps: int, warmup: float = 0.05) -> None:
        self.lr = lr
        self.final_lr = 0
        self.total_steps = total_steps
        self.warmup_steps = round(total_steps * warmup)

    def get_lr(self, step: int) -> float:
        if step < self.warmup_steps:
            return self.lr * step / self.warmup_steps
        if step < self.total_steps:
            progress = (step - self.warmup_steps) / (self.total_steps - self.warmup_steps)
            return self.final_lr + 0.5 * (self.lr - self.final_lr) * (1 + math.cos(progress * math.pi))
        return self.final_lr


class WandbLogger:
    def __init__(self, args):
        if args.project is not None and not args.profile:
            import wandb

            Path("wandb_logs").mkdir(exist_ok=True)
            self.run = wandb.init(project=args.project, name=args.run_name, config=args, dir="wandb_logs")
        else:
            self.run = None

    def log(self, *args, **kwargs):
        if self.run is not None:
            self.run.log(*args, **kwargs)


def get_parser():
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", required=True)

    parser.add_argument("--amp", default="none")
    parser.add_argument("--channels_last", action="store_true")
    parser.add_argument("--compile", action="store_true")

    parser.add_argument("--n_epochs", type=int, default=10)
    parser.add_argument("--batch_size", type=int, default=64)
    parser.add_argument("--n_workers", type=int, default=4)

    parser.add_argument("--optim", default="Adam")
    parser.add_argument("--lr", type=float, default=1e-4)
    parser.add_argument("--weight_decay", type=float, default=0)
    parser.add_argument("--cosine_lr_scheduler", action="store_true")

    parser.add_argument("--project")
    parser.add_argument("--run_name", default="debug")
    parser.add_argument("--profile", action="store_true")
    return parser


def get_dloader(args, training: bool):
    transforms = [v2.ToImage()]

    if training:
        transforms.extend([v2.RandomResizedCrop(224), v2.RandomHorizontalFlip()])
    else:
        transforms.extend([v2.Resize(256), v2.CenterCrop(224)])

    transforms.append(v2.ToDtype(torch.float32, scale=True))
    transforms.append(v2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]))
    transforms = v2.Compose(transforms)

    # use dataset from HF so download is fast
    ds = datasets.load_dataset("timm/resisc45", split="train" if training else "validation")
    ds = ds.select_columns(["image", "label"])
    ds.set_transform(lambda x: dict(image=transforms(x["image"]), label=x["label"]))

    return DataLoader(
        ds,
        batch_size=args.batch_size,
        shuffle=training,
        num_workers=args.n_workers,
        pin_memory=training,
        drop_last=training,
    )


def get_amp_ctx(amp):
    dtype = dict(bf16=torch.bfloat16, fp16=torch.float16, none=None)[amp]
    return torch.autocast("cuda", dtype=dtype, enabled=amp != "none")


@torch.no_grad()
def evaluate_model(model, args):
    model.eval()
    val_dloader = get_dloader(args, False)

    all_labels = []
    all_preds = []

    for batch in tqdm(val_dloader, dynamic_ncols=True, desc="Evaluating"):
        all_labels.append(batch["label"].clone())
        if args.channels_last:
            batch["image"] = batch["image"].to(memory_format=torch.channels_last)

        with get_amp_ctx(args.amp):
            all_preds.append(model(batch["image"].cuda()).argmax(1).cpu())

    all_labels = torch.cat(all_labels, dim=0)
    all_preds = torch.cat(all_preds, dim=0)

    acc = (all_labels == all_preds).float().mean()
    return acc


if __name__ == "__main__":
    args = get_parser().parse_args()

    if args.profile:
        args.n_epochs = 1

    for k, v in vars(args).items():
        print(f"{k}: {v}")

    # wandb is only enabled when args.project is set and args.profile is False
    logger = WandbLogger(args)
    dloader = get_dloader(args, True)
    print(f"Train dataset: {len(dloader.dataset):,} images")

    model = timm.create_model(args.model, pretrained=True, num_classes=45).cuda()
    if args.channels_last:
        model.to(memory_format=torch.channels_last)
    if args.compile:
        model.compile(fullgraph=True)
    print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")

    OPTIM_MAP = dict(
        Adam=torch.optim.Adam,
        Adam8bitBnb=bnb.optim.Adam8bit,
        Adam8bitAo=Adam8bit,
    )
    optim = OPTIM_MAP[args.optim](model.parameters(), args.lr, weight_decay=args.weight_decay)
    lr_schedule = CosineSchedule(args.lr, len(dloader) * args.n_epochs)

    grad_scaler = torch.amp.GradScaler("cuda", enabled=args.amp == "fp16")

    step = 0
    for epoch_idx in range(args.n_epochs):
        model.train()
        prof = profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) if args.profile else nullcontext()

        with prof:
            for batch in tqdm(dloader, dynamic_ncols=True, desc=f"Epoch {epoch_idx + 1}/{args.n_epochs}"):
                if args.channels_last:
                    batch["image"] = batch["image"].to(memory_format=torch.channels_last)

                with get_amp_ctx(args.amp):
                    loss = F.cross_entropy(model(batch["image"].cuda()), batch["label"].cuda())
                grad_scaler.scale(loss).backward()

                if args.cosine_lr_scheduler:
                    lr = lr_schedule.get_lr(step)
                    for param_group in optim.param_groups:
                        param_group["lr"] = lr

                if step % 100 == 0:
                    logger.log(dict(loss=loss.item(), lr=optim.param_groups[0]["lr"]), step=step)

                grad_scaler.step(optim)
                grad_scaler.update()
                optim.zero_grad()

                step += 1

                if args.profile and step == 20:
                    break

        if args.profile:
            prof.export_chrome_trace("trace.json")
        else:
            val_acc = evaluate_model(model, args)
            print(f"Epoch {epoch_idx + 1}/{args.n_epochs}: val_acc={val_acc.item() * 100:.2f}")
            logger.log(dict(val_acc=val_acc), step=step)

    print(f"Max memory allocated: {torch.cuda.max_memory_allocated() / (1 << 30):.2f} GB")
```

test/prototype/test_optim_8bit.py

Lines changed: 84 additions & 0 deletions
@@ -0,0 +1,84 @@
```python
import copy

import pytest
import torch
from torch import nn
from torch.testing._internal.common_utils import (
    TestCase,
    instantiate_parametrized_tests,
    parametrize,
    run_tests,
)
from torchao.prototype import optim_8bit
from torchao.prototype.optim_8bit.subclass import quantize_8bit_with_qmap, QMAP_SIGNED
from torchao.utils import TORCH_VERSION_AFTER_2_3

try:
    import bitsandbytes as bnb
except ImportError:
    bnb = None


_DEVICES = ["cpu"] + (["cuda"] if torch.cuda.is_available() else [])


# DTQ = dynamic tree quantization
class TestDTQ8bit(TestCase):
    @parametrize("device", _DEVICES)
    def test_quantize_8bit_with_qmap_correctness(self, device):
        x = torch.randn(32, 1024, device=device)
        qmap = torch.tensor(QMAP_SIGNED, device=device)

        actual_codes, actual_scale = quantize_8bit_with_qmap(x, qmap, 256, implementation=1)
        expected_codes, expected_scale = quantize_8bit_with_qmap(x, qmap, 256, implementation=0)

        torch.testing.assert_close(actual_codes, expected_codes)
        torch.testing.assert_close(actual_scale, expected_scale)

    @parametrize("device", _DEVICES)
    def test_quantize_8bit_with_qmap_compile(self, device):
        x = torch.randn(32, 1024, device=device)
        qmap = torch.tensor(QMAP_SIGNED, device=device)

        actual_codes, actual_scale = torch.compile(quantize_8bit_with_qmap, fullgraph=True)(x, qmap, 256)
        expected_codes, expected_scale = quantize_8bit_with_qmap(x, qmap, 256)

        torch.testing.assert_close(actual_codes, expected_codes)
        torch.testing.assert_close(actual_scale, expected_scale)


@pytest.mark.skipif(bnb is None, reason="bitsandbytes is not available")
@pytest.mark.skipif(not torch.cuda.is_available(), reason="bitsandbytes 8-bit Adam only works for CUDA")
@pytest.mark.xfail(not TORCH_VERSION_AFTER_2_3, reason="torch.compile() fails for PyTorch < 2.3")
class TestOptim8bit(TestCase):
    @parametrize("optim_name", ["Adam8bit", "AdamW8bit"])
    def test_adam_8bit_correctness(self, optim_name):
        device = "cuda"
        model1 = nn.Sequential(nn.Linear(32, 1024), nn.ReLU(), nn.Linear(1024, 128)).to(device)
        model2 = copy.deepcopy(model1)

        optim1 = getattr(bnb.optim, optim_name)(model1.parameters())
        optim2 = getattr(optim_8bit, optim_name)(model2.parameters())

        for _ in range(2):
            x = torch.randn(4, 32, device=device)

            loss1 = model1(x).sum()
            loss1.backward()
            optim1.step()
            optim1.zero_grad()

            loss2 = model2(x).sum()
            loss2.backward()
            optim2.step()
            optim2.zero_grad()

        for p1, p2 in zip(model1.parameters(), model2.parameters()):
            torch.testing.assert_close(p2, p1, rtol=1e-5, atol=1e-5)


instantiate_parametrized_tests(TestDTQ8bit)
instantiate_parametrized_tests(TestOptim8bit)


if __name__ == "__main__":
    run_tests()
```

torchao/prototype/README.md

Lines changed: 1 addition & 0 deletions
@@ -10,6 +10,7 @@
 - `galore/kernels` - `triton` kernels that fuse various steps of the `GaLore` algorithm
 - `galore/docs` - implementation notes and discussion of issues faced in kernel design.
 - [`quant_llm`](quant_llm) - FP16 x FPx mixed matmul kernel per [FP6-LLM](https://arxiv.org/abs/2401.14112)
+- [`optim_8bit`](optim_8bit) - re-implementation of 8-bit optimizers from [bitsandbytes](https://github.com/TimDettmers/bitsandbytes).

 #### Roadmap

torchao/prototype/optim_8bit/README.md

Lines changed: 38 additions & 0 deletions
@@ -0,0 +1,38 @@
# 8-bit optimizers

This folder implements 8-bit optimizers using dynamic tree quantization as outlined in https://arxiv.org/abs/2110.02861. The implementation is done entirely in Python (with a tensor subclass) and relies on `torch.compile()` to generate efficient fused kernels.
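
For intuition, here is a rough sketch of the block-wise quantization primitive that the tests exercise as `quantize_8bit_with_qmap(x, qmap, block_size)`: values in each block are scaled by the block's absmax, then each scaled value is mapped to the index of the nearest entry of a fixed, sorted 256-value lookup table (the qmap). The function name and exact semantics below are assumptions for illustration, not the torchao implementation:

```python
import torch

def quantize_with_qmap_sketch(x: torch.Tensor, qmap: torch.Tensor, block_size: int = 2048):
    """Illustrative block-wise 8-bit quantization against a sorted 256-entry qmap.

    Returns (codes, scale): uint8 indices into `qmap` plus one fp32 scale per block.
    Assumes x.numel() is divisible by block_size and qmap is sorted ascending in [-1, 1].
    """
    blocks = x.float().reshape(-1, block_size)
    scale = blocks.abs().amax(dim=1).clamp(min=1e-12)  # per-block absmax scale
    normed = blocks / scale.unsqueeze(1)               # normalized to [-1, 1]

    # Nearest-neighbor lookup into the sorted qmap via bucketize.
    hi = torch.bucketize(normed, qmap).clamp(max=qmap.numel() - 1)
    lo = (hi - 1).clamp(min=0)
    pick_hi = (qmap[hi] - normed).abs() < (normed - qmap[lo]).abs()
    codes = torch.where(pick_hi, hi, lo).to(torch.uint8)
    return codes, scale

# Dequantization is just a table lookup followed by rescaling:
# x_hat = qmap[codes.long()].reshape(-1, block_size) * scale.unsqueeze(1)
```

With `torch.compile()`, these elementwise steps can fuse into a single kernel, which is what lets a pure-Python tensor-subclass implementation stay competitive with hand-written CUDA.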

## Usage

This is a drop-in replacement for `torch.optim.Adam`:

```python
from torchao.prototype.optim_8bit import Adam8bit

model = ...
optim = Adam8bit(model.parameters())
```

You can also change the quantization block size (default 2048) by passing `block_size=value` to the optimizer.
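
For example (only the `block_size` keyword is documented above; the `lr` value and the choice of 1024 are arbitrary here):

```python
# Smaller blocks mean finer-grained scales at the cost of storing more of them.
optim = Adam8bit(model.parameters(), lr=1e-4, block_size=1024)
```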

**Other optimizers**: AdamW is also available as `AdamW8bit`.

NOTE: this requires PyTorch >= 2.3.

## Benchmarks

A benchmark script for fine-tuning a [timm](https://github.com/huggingface/pytorch-image-models) model on the [resisc45](https://huggingface.co/datasets/timm/resisc45) dataset is available at [benchmarks/benchmark_adam_8bit.py](../../../benchmarks/benchmark_adam_8bit.py).

Results for fine-tuning ViT-B with BF16 AMP on a 4070Ti SUPER:

Adam impl | max memory (GB) | training time | accuracy
----------|-----------------|---------------|---------
PyTorch   | 5.26            | 9m 11s        | 93.62%
bnb 8-bit | 4.78            | 9m 10s        | 93.06%
ao 8-bit  | 4.78            | 9m 15s        | 94.14%

**Known issue**: when the learning rate is updated every step (e.g. with a cosine learning rate scheduler), training is slower. This is because we have to convert the learning rate to a CUDA tensor (which incurs an expensive memory transfer), since `torch.compile()` treats a Python float as a constant and triggers a recompile whenever the value changes.
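
A standalone sketch of the trade-off described above (`adam_like_update` is a made-up stand-in for the fused optimizer step, not torchao code):

```python
import torch

@torch.compile
def adam_like_update(param, grad, lr):
    # With a Python float `lr`, torch.compile bakes the value into the graph
    # as a constant, so every new lr triggers a recompile. With a 0-dim
    # tensor, lr is an ordinary graph input and may change between calls.
    return param - lr * grad

param = torch.randn(1024, device="cuda")
grad = torch.randn(1024, device="cuda")

lr = torch.tensor(1e-4, device="cuda")  # 0-dim CUDA tensor
for step in range(3):
    # The in-place update keeps the compiled graph stable, but copying a
    # Python float into a CUDA tensor is a host-to-device transfer each step.
    lr.fill_(1e-4 * (step + 1))
    param = adam_like_update(param, grad, lr)
```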

## Credits

Credits to Tim Dettmers for creating the wonderful bitsandbytes library.

torchao/prototype/optim_8bit/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -0,0 +1,2 @@
```python
from .adam import Adam8bit
from .adamw import AdamW8bit
```
