Commit 3a2dcc1

Merge pull request #594 from cdoern/refactor-accelerator
feat: refactor main_ds.py (2/n) Accelerator class
2 parents b75fd4e + bf4be9f

File tree

4 files changed, +542 -193 lines changed

Lines changed: 250 additions & 0 deletions
@@ -0,0 +1,250 @@
# Standard
from copy import deepcopy
from typing import Callable, Optional

# Third Party
from accelerate import Accelerator as TransformersAccel
from torch.utils.data import DataLoader
from transformers import get_scheduler
import torch

# First Party
from instructlab.training.config import (  # Adjust this import if needed
    DeepSpeedOptions,
    DistributedBackend,
)

# Local
from .model import Model


class Accelerator:
    def __init__(
        self,
        model: Model,
        samples_per_gpu: int,
        grad_accum: int,
        train_loader: DataLoader,
        save_samples: int,
        distributed_framework: DistributedBackend,  # the distributed framework is primarily associated with the Accelerator
        fsdp_sharding_strategy: Optional[str] = None,
        deepspeed_cpu_offload_optimizer: Optional[bool] = False,
        deepspeed_cpu_offload_optimizer_pin_memory: Optional[bool] = False,
        deepspeed_cpu_offload_optimizer_ratio: Optional[float] = None,
        fsdp_cpu_offload_params: Optional[bool] = False,
    ):
        self.samples_per_gpu = samples_per_gpu
        self.save_samples = save_samples
        self.grad_accum = grad_accum
        self.model = model
        self.distributed_framework = distributed_framework
        self.fsdp_sharding_strategy = fsdp_sharding_strategy
        self.deepspeed_cpu_offload_optimizer = deepspeed_cpu_offload_optimizer
        self.deepspeed_cpu_offload_optimizer_pin_memory = (
            deepspeed_cpu_offload_optimizer_pin_memory
        )
        self.train_loader = train_loader
        self.deepspeed_cpu_offload_optimizer_ratio = (
            deepspeed_cpu_offload_optimizer_ratio
        )
        self.fsdp_cpu_offload_params = fsdp_cpu_offload_params

        if self.distributed_framework == DistributedBackend.DEEPSPEED:
            accel_args = {
                "deepspeed_plugin": self.get_ds_plugin(
                    world_size=torch.distributed.get_world_size(),
                    samples_per_gpu=samples_per_gpu,
                    grad_accum=grad_accum,
                    opts=DeepSpeedOptions(
                        cpu_offload_optimizer=deepspeed_cpu_offload_optimizer,
                        cpu_offload_optimizer_ratio=self.deepspeed_cpu_offload_optimizer_ratio,
                        cpu_offload_optimizer_pin_memory=self.deepspeed_cpu_offload_optimizer_pin_memory,
                        save_samples=save_samples,
                    ),
                ),
            }
        elif self.distributed_framework == DistributedBackend.FSDP:
            accel_args = {
                "fsdp_plugin": self.get_fsdp_config(),
                "mixed_precision": "bf16",
            }
        self.accelerator = TransformersAccel(
            **accel_args,
        )
        self.accelerator.even_batches = False

        new_m = self.accelerator.prepare(model.model)
        self.model.update_model(new_m)

    def prepare_with_optimizer(
        self,
        optimizer: torch.optim.Optimizer,
        lr_scheduler: str,
        num_epochs: int,
        num_warmup_steps: int,
    ):
        self.lr_scheduler: Callable
        self.setup_lr_scheduler(
            optimizer=optimizer,
            lr_scheduler=lr_scheduler,
            num_epochs=num_epochs,
            num_warmup_steps=num_warmup_steps,
        )
        new_m, new_opt, _, self.lr_scheduler = self.accelerator.prepare(
            self.model.model,
            optimizer,
            deepcopy(self.train_loader),
            self.lr_scheduler,
        )
        self.lr_scheduler.split_batches = True
        self.model.update_model(new_m)
        self.optimizer = new_opt

    def setup_lr_scheduler(
        self,
        optimizer: torch.optim.Optimizer,
        lr_scheduler: str,
        num_epochs: int,
        num_warmup_steps: int,
    ):
        self.lr_scheduler = get_scheduler(
            name=lr_scheduler,
            optimizer=optimizer,
            num_warmup_steps=num_warmup_steps,
            num_training_steps=num_epochs * len(self.train_loader) // self.grad_accum,
        )

    def __getattr__(self, name):
        # Forward anything not found to the underlying accelerator
        return getattr(self.accelerator, name)

    def get_fsdp_config(self):
        # Standard
        from functools import partial

        # Third Party
        from accelerate.utils import FullyShardedDataParallelPlugin
        from peft.utils.other import fsdp_auto_wrap_policy
        from torch.distributed.fsdp import BackwardPrefetch, ShardingStrategy
        from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload
        from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy

        # First Party
        from instructlab.training.utils import get_module_class_from_name

        is_lora = self.model.lora_config is not None
        block_name = self.model._no_split_modules[0]

        wrap_policy = None
        if is_lora:
            wrap_policy = fsdp_auto_wrap_policy(self.model)
        else:
            wrap_policy = partial(
                transformer_auto_wrap_policy,
                transformer_layer_cls={
                    get_module_class_from_name(self.model, block_name),
                },
            )

        # TODO(osilkin): BACKWARD_POST trades memory utilization for processing time, which is important for systems utilizing LoRA
        # We should have this be configurable in the future.
        prefetch_policy = (
            BackwardPrefetch.BACKWARD_POST if is_lora else BackwardPrefetch.BACKWARD_PRE
        )
        fsdp_plugin = FullyShardedDataParallelPlugin(
            auto_wrap_policy=wrap_policy,
            limit_all_gathers=True,
            backward_prefetch=prefetch_policy,
            sharding_strategy=ShardingStrategy[self.fsdp_sharding_strategy],
            cpu_offload=CPUOffload(self.fsdp_cpu_offload_params),
        )

        # `use_orig_params` must be disabled when using LoRA and FSDP together
        # Source: https://huggingface.co/docs/peft/en/accelerate/fsdp#the-important-parts
        if self.model.lora_config is not None:
            fsdp_plugin.use_orig_params = False

        return fsdp_plugin

    def get_ds_plugin(
        self, world_size, samples_per_gpu, grad_accum, opts: DeepSpeedOptions
    ):
        # Third Party
        from accelerate.utils import DeepSpeedPlugin

        ds_config = {
            "train_batch_size": samples_per_gpu * world_size * grad_accum,
            "gradient_accumulation_steps": grad_accum,
            "train_micro_batch_size_per_gpu": samples_per_gpu,
            "steps_per_print": 1,
            "zero_optimization": {
                "stage": 2,
                # this option is only supported with DeepSpeed ZeRO stage 3
                "offload_param": {"device": "none"},
                "offload_optimizer": {"device": "none"},
            },
            "bf16": {"enabled": True},
            "gradient_clipping": 1.0,
            "prescale_gradients": False,
            "wall_clock_breakdown": False,
        }

        if opts.cpu_offload_optimizer:
            # this only works when the cpu offload optimizer is enabled
            ds_config["zero_optimization"]["offload_optimizer"] = {
                # CPU offloading is the only option available in ZeRO stage 2
                "device": "cpu",
                "pin_memory": opts.cpu_offload_optimizer_pin_memory,
                "ratio": opts.cpu_offload_optimizer_ratio,
            }
        ds_plugin = DeepSpeedPlugin(
            hf_ds_config=ds_config,
        )
        return ds_plugin

    @classmethod
    def setup_deepspeed(
        cls,
        model: Model,
        samples_per_gpu: int,
        grad_accum: int,
        train_loader: DataLoader,
        deepspeed_cpu_offload_optimizer: Optional[bool],
        deepspeed_cpu_offload_optimizer_pin_memory: Optional[bool],
        deepspeed_cpu_offload_optimizer_ratio: float,
        save_samples: int,
    ):
        return cls(
            model=model,
            grad_accum=grad_accum,
            train_loader=train_loader,
            distributed_framework=DistributedBackend.DEEPSPEED,
            samples_per_gpu=samples_per_gpu,
            deepspeed_cpu_offload_optimizer=deepspeed_cpu_offload_optimizer,
            deepspeed_cpu_offload_optimizer_pin_memory=deepspeed_cpu_offload_optimizer_pin_memory,
            deepspeed_cpu_offload_optimizer_ratio=deepspeed_cpu_offload_optimizer_ratio,
            save_samples=save_samples,
        )

    @classmethod
    def setup_fsdp(
        cls,
        model: Model,
        samples_per_gpu: int,
        grad_accum: int,
        train_loader: DataLoader,
        fsdp_sharding_strategy: Optional[str],
        fsdp_cpu_offload_params: bool,
        save_samples: int,
    ):
        return cls(
            model=model,
            grad_accum=grad_accum,
            train_loader=train_loader,
            distributed_framework=DistributedBackend.FSDP,
            samples_per_gpu=samples_per_gpu,
            fsdp_sharding_strategy=fsdp_sharding_strategy,
            fsdp_cpu_offload_params=fsdp_cpu_offload_params,
            save_samples=save_samples,
        )
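
Usage sketch (not part of this diff): the snippet below shows how main_ds.py might construct the new class through the FSDP classmethod and then hand it an optimizer. The `model` wrapper (a `Model` instance), the `train_loader`, the module path `instructlab.training.accelerator`, and every hyperparameter value are assumptions for illustration; only the `Accelerator` methods themselves come from this commit.

# Hypothetical wiring, assuming a distributed environment (e.g. torchrun) is already
# initialized and that `model` (a Model wrapper) and `train_loader` (a DataLoader)
# were built earlier in main_ds.py -- neither is defined in this diff.
import torch

from instructlab.training.accelerator import Accelerator  # module path assumed

accelerator = Accelerator.setup_fsdp(
    model=model,
    samples_per_gpu=8,
    grad_accum=4,
    train_loader=train_loader,
    fsdp_sharding_strategy="FULL_SHARD",  # must name a torch ShardingStrategy member
    fsdp_cpu_offload_params=False,
    save_samples=1000,
)

# The optimizer is built against the wrapped torch module (model.model) and handed
# back to the Accelerator, which prepares it together with an LR scheduler.
optimizer = torch.optim.AdamW(model.model.parameters(), lr=2e-5)
accelerator.prepare_with_optimizer(
    optimizer=optimizer,
    lr_scheduler="cosine",
    num_epochs=3,
    num_warmup_steps=100,
)

# Anything not defined on the wrapper (e.g. backward()) is forwarded to the
# underlying accelerate.Accelerator via __getattr__.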
