diff --git a/src/instructlab/training/accelerator.py b/src/instructlab/training/accelerator.py
new file mode 100644
index 00000000..b03c4a45
--- /dev/null
+++ b/src/instructlab/training/accelerator.py
@@ -0,0 +1,250 @@
+# Standard
+from copy import deepcopy
+from typing import Callable, Optional
+
+# Third Party
+from accelerate import Accelerator as TransformersAccel
+from torch.utils.data import DataLoader
+from transformers import get_scheduler
+import torch
+
+# First Party
+from instructlab.training.config import (  # Adjust this import if needed
+    DeepSpeedOptions,
+    DistributedBackend,
+)
+
+# Local
+from .model import Model
+
+
+class Accelerator:
+    def __init__(
+        self,
+        model: Model,
+        samples_per_gpu: int,
+        grad_accum: int,
+        train_loader: DataLoader,
+        save_samples: int,
+        distributed_framework: DistributedBackend,  # the distributed framework is primarily associated with the Accelerator
+        fsdp_sharding_strategy: Optional[str] = None,
+        deepspeed_cpu_offload_optimizer: Optional[bool] = False,
+        deepspeed_cpu_offload_optimizer_pin_memory: Optional[bool] = False,
+        deepspeed_cpu_offload_optimizer_ratio: Optional[float] = None,
+        fsdp_cpu_offload_params: Optional[bool] = False,
+    ):
+        self.samples_per_gpu = samples_per_gpu
+        self.save_samples = save_samples
+        self.grad_accum = grad_accum
+        self.model = model
+        self.distributed_framework = distributed_framework
+        self.fsdp_sharding_strategy = fsdp_sharding_strategy
+        self.deepspeed_cpu_offload_optimizer = deepspeed_cpu_offload_optimizer
+        self.deepspeed_cpu_offload_optimizer_pin_memory = (
+            deepspeed_cpu_offload_optimizer_pin_memory
+        )
+        self.train_loader = train_loader
+        self.deepspeed_cpu_offload_optimizer_ratio = (
+            deepspeed_cpu_offload_optimizer_ratio
+        )
+        self.fsdp_cpu_offload_params = fsdp_cpu_offload_params
+
+        if self.distributed_framework == DistributedBackend.DEEPSPEED:
+            accel_args = {
+                "deepspeed_plugin": self.get_ds_plugin(
+                    world_size=torch.distributed.get_world_size(),
+                    samples_per_gpu=samples_per_gpu,
+                    grad_accum=grad_accum,
+                    opts=DeepSpeedOptions(
+                        cpu_offload_optimizer=deepspeed_cpu_offload_optimizer,
+                        cpu_offload_optimizer_ratio=self.deepspeed_cpu_offload_optimizer_ratio,
+                        cpu_offload_optimizer_pin_memory=self.deepspeed_cpu_offload_optimizer_pin_memory,
+                        save_samples=save_samples,
+                    ),
+                ),
+            }
+        elif self.distributed_framework == DistributedBackend.FSDP:
+            accel_args = {
+                "fsdp_plugin": self.get_fsdp_config(),
+                "mixed_precision": "bf16",
+            }
+        self.accelerator = TransformersAccel(
+            **accel_args,
+        )
+        self.accelerator.even_batches = False
+
+        new_m = self.accelerator.prepare(model.model)
+        self.model.update_model(new_m)
+
+    def prepare_with_optimizer(
+        self,
+        optimizer: torch.optim.Optimizer,
+        lr_scheduler: str,
+        num_epochs: int,
+        num_warmup_steps: int,
+    ):
+        self.lr_scheduler: Callable
+        self.setup_lr_scheduler(
+            optimizer=optimizer,
+            lr_scheduler=lr_scheduler,
+            num_epochs=num_epochs,
+            num_warmup_steps=num_warmup_steps,
+        )
+        new_m, new_opt, _, self.lr_scheduler = self.accelerator.prepare(
+            self.model.model,
+            optimizer,
+            deepcopy(self.train_loader),
+            self.lr_scheduler,
+        )
+        self.lr_scheduler.split_batches = True
+        self.model.update_model(new_m)
+        self.optimizer = new_opt
+
+    def setup_lr_scheduler(
+        self,
+        optimizer: torch.optim.Optimizer,
+        lr_scheduler: str,
+        num_epochs: int,
+        num_warmup_steps: int,
+    ):
+        self.lr_scheduler = get_scheduler(
+            name=lr_scheduler,
+            optimizer=optimizer,
+            num_warmup_steps=num_warmup_steps,
+            num_training_steps=num_epochs * len(self.train_loader) // self.grad_accum,
+        )
+
+    def __getattr__(self, name):
+        # Forward anything not found to the underlying accelerator
+        return getattr(self.accelerator, name)
+
+    def get_fsdp_config(self):
+        # Standard
+        from functools import partial
+
+        # Third Party
+        from accelerate.utils import FullyShardedDataParallelPlugin
+        from peft.utils.other import fsdp_auto_wrap_policy
+        from torch.distributed.fsdp import BackwardPrefetch, ShardingStrategy
+        from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload
+        from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
+
+        # First Party
+        from instructlab.training.utils import get_module_class_from_name
+
+        is_lora = self.model.lora_config is not None
+        block_name = self.model._no_split_modules[0]
+
+        wrap_policy = None
+        if is_lora:
+            wrap_policy = fsdp_auto_wrap_policy(self.model)
+        else:
+            wrap_policy = partial(
+                transformer_auto_wrap_policy,
+                transformer_layer_cls={
+                    get_module_class_from_name(self.model, block_name),
+                },
+            )
+
+        # TODO(osilkin): BACKWARD_POST trades memory utilization for processing time, which is important for systems utilizing LoRA
+        # We should have this be configurable in the future.
+        prefetch_policy = (
+            BackwardPrefetch.BACKWARD_POST if is_lora else BackwardPrefetch.BACKWARD_PRE
+        )
+        fsdp_plugin = FullyShardedDataParallelPlugin(
+            auto_wrap_policy=wrap_policy,
+            limit_all_gathers=True,
+            backward_prefetch=prefetch_policy,
+            sharding_strategy=ShardingStrategy[self.fsdp_sharding_strategy],
+            cpu_offload=CPUOffload(self.fsdp_cpu_offload_params),
+        )
+
+        # `use_orig_params` must be disabled when using LoRA and FSDP together
+        # Source: https://huggingface.co/docs/peft/en/accelerate/fsdp#the-important-parts
+        if self.model.lora_config is not None:
+            fsdp_plugin.use_orig_params = False
+
+        return fsdp_plugin
+
+    def get_ds_plugin(
+        self, world_size, samples_per_gpu, grad_accum, opts: DeepSpeedOptions
+    ):
+        # Third Party
+        from accelerate.utils import DeepSpeedPlugin
+
+        ds_config = {
+            "train_batch_size": samples_per_gpu * world_size * grad_accum,
+            "gradient_accumulation_steps": grad_accum,
+            "train_micro_batch_size_per_gpu": samples_per_gpu,
+            "steps_per_print": 1,
+            "zero_optimization": {
+                "stage": 2,
+                # this option is only supported with DeepSpeed ZeRO stage 3
+                "offload_param": {"device": "none"},
+                "offload_optimizer": {"device": "none"},
+            },
+            "bf16": {"enabled": True},
+            "gradient_clipping": 1.0,
+            "prescale_gradients": False,
+            "wall_clock_breakdown": False,
+        }
+
+        if opts.cpu_offload_optimizer:
+            # this only works when the cpu offload optimizer is enabled
+            ds_config["zero_optimization"]["offload_optimizer"] = {
+                # CPU offloading is the only option available in ZeRO stage 2
+                "device": "cpu",
+                "pin_memory": opts.cpu_offload_optimizer_pin_memory,
+                "ratio": opts.cpu_offload_optimizer_ratio,
+            }
+        ds_plugin = DeepSpeedPlugin(
+            hf_ds_config=ds_config,
+        )
+        return ds_plugin
+
+    @classmethod
+    def setup_deepspeed(
+        cls,
+        model: Model,
+        samples_per_gpu: int,
+        grad_accum: int,
+        train_loader: DataLoader,
+        deepspeed_cpu_offload_optimizer: Optional[bool],
+        deepspeed_cpu_offload_optimizer_pin_memory: Optional[bool],
+        deepspeed_cpu_offload_optimizer_ratio: float,
+        save_samples: int,
+    ):
+        return cls(
+            model=model,
+            grad_accum=grad_accum,
+            train_loader=train_loader,
+            distributed_framework=DistributedBackend.DEEPSPEED,
+            samples_per_gpu=samples_per_gpu,
+            deepspeed_cpu_offload_optimizer=deepspeed_cpu_offload_optimizer,
deepspeed_cpu_offload_optimizer_pin_memory=deepspeed_cpu_offload_optimizer_pin_memory, + deepspeed_cpu_offload_optimizer_ratio=deepspeed_cpu_offload_optimizer_ratio, + save_samples=save_samples, + ) + + @classmethod + def setup_fsdp( + cls, + model: Model, + samples_per_gpu: int, + grad_accum: int, + train_loader: DataLoader, + fsdp_sharding_strategy: Optional[str], + fsdp_cpu_offload_params: bool, + save_samples: int, + ): + return cls( + model=model, + grad_accum=grad_accum, + train_loader=train_loader, + distributed_framework=DistributedBackend.FSDP, + samples_per_gpu=samples_per_gpu, + fsdp_sharding_strategy=fsdp_sharding_strategy, + fsdp_cpu_offload_params=fsdp_cpu_offload_params, + save_samples=save_samples, + ) diff --git a/src/instructlab/training/main_ds.py b/src/instructlab/training/main_ds.py index ccb417d1..4ca638d0 100644 --- a/src/instructlab/training/main_ds.py +++ b/src/instructlab/training/main_ds.py @@ -1,7 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # Standard -from copy import deepcopy import argparse import datetime import logging @@ -10,9 +9,6 @@ import time import warnings -# Third Party -from accelerate import Accelerator - try: # Third Party from deepspeed.ops.adam import DeepSpeedCPUAdam @@ -38,14 +34,14 @@ ) # Third Party -from torch.utils.data import DataLoader from tqdm import tqdm -from transformers import AutoConfig, PreTrainedTokenizer, get_scheduler +from transformers import AutoConfig import torch import torch.distributed # First Party from instructlab.training import config +from instructlab.training.accelerator import Accelerator from instructlab.training.config import ( DistributedBackend, ModelTypes, @@ -69,7 +65,6 @@ from instructlab.training.multipack_sampler import ( find_packing_max_batch_len_and_grad_accum, ) -from instructlab.training.setup_accelerator import setup_accelerator from instructlab.training.token_dataset import setup_dataloader, setup_dataset from instructlab.training.tokenizer_utils import setup_tokenizer from instructlab.training.utils import ( @@ -87,13 +82,9 @@ def train( args, - model, - optimizer, - lr_scheduler, + model: Model, + optimizer: torch.optim.Optimizer, accelerator: Accelerator, - tokenizer: PreTrainedTokenizer, - train_loader: DataLoader, - grad_accum, ): model.train() @@ -104,15 +95,15 @@ def train( metric_logger = logging.getLogger("instructlab.training.metrics") base_logger = logging.getLogger("instructlab.training") - batch_size = args.effective_batch_size // grad_accum + batch_size = args.effective_batch_size // accelerator.grad_accum samples_seen = 0 if hasattr(args, "samples_seen"): logger.info("Updating 'samples_seen' %d", args.samples_seen) samples_seen = args.samples_seen - if args.save_samples > 0: - args.save_samples = (args.save_samples // batch_size) * batch_size + if accelerator.save_samples > 0: + accelerator.save_samples = (accelerator.save_samples // batch_size) * batch_size logger.info("Number of samples per save: %d", args.save_samples) if args.save_samples_ds is not None: @@ -122,18 +113,18 @@ def train( global_grad_norm = None for epoch in range(args.current_epoch, args.num_epochs): if args.sampler in ("multipack"): - train_loader.batch_sampler.set_epoch(epoch) + accelerator.train_loader.batch_sampler.set_epoch(epoch) elif args.sampler in ("distributed"): - train_loader.sampler.set_epoch(epoch) + accelerator.train_loader.sampler.set_epoch(epoch) else: raise NotADirectoryError - num_epoch_steps = len(train_loader) + num_epoch_steps = len(accelerator.train_loader) if local_rank == 0: inner_pb = 
tqdm(range(num_epoch_steps), desc=f"Epoch {epoch}") # blast through the batches in the train loader up to the last step within the epoch. - for batch in train_loader: + for batch in accelerator.train_loader: if global_step <= args.last_step: # in the case of resuming, last_step > 0 global_step += 1 @@ -178,16 +169,16 @@ def train( ) accelerator.backward(loss) - if global_step % grad_accum == 0: + if global_step % accelerator.grad_accum == 0: global_grad_norm = accelerator.clip_grad_norm_(model.parameters(), 1.0) optimizer.step() - lr_scheduler.step() + accelerator.lr_scheduler.step() optimizer.zero_grad() if local_rank == 0: elapsed_time = time.time() - start overall_throughput = args.samples_per_gpu * world_size / elapsed_time - current_lr = lr_scheduler.get_last_lr()[0] + current_lr = accelerator.lr_scheduler.get_last_lr()[0] cuda_mem_allocated = torch.cuda.memory_allocated() / (1024**3) cuda_malloc_retries = torch.cuda.memory_stats()["num_alloc_retries"] global_grad_norm = ( @@ -219,7 +210,7 @@ def train( "total_loss": float(log_loss / num_loss_counted_tokens), "samples_seen": samples_seen, "gradnorm": global_grad_norm, - "total_samples": len(train_loader.dataset), + "total_samples": len(accelerator.train_loader.dataset), "num_epoch_steps": num_epoch_steps, # "weight_norm": weight_norm, }, @@ -234,7 +225,7 @@ def train( args=args, accelerator=accelerator, model=model, - tokenizer=tokenizer, + tokenizer=model.tokenizer, samples_seen=samples_seen, is_lora=bool(args.lora_r), hf_format=True, @@ -252,7 +243,7 @@ def train( args=args, accelerator=accelerator, model=model, - tokenizer=tokenizer, + tokenizer=model.tokenizer, samples_seen=samples_seen, is_lora=bool(args.lora_r), full_state=args.accelerate_full_state_at_epoch, @@ -266,7 +257,7 @@ def train( save_hf_format_accelerate( args, model, - tokenizer, + model.tokenizer, accelerator, samples_seen, is_lora=bool(args.lora_r), @@ -460,51 +451,45 @@ def main(args): }, extra={"hparams": True}, ) - - # TODO cdoern: `m.model` should be hidden behind a custom `Accelerator class` - # when using the training library `Model` class should be first class - accelerator = setup_accelerator(args, m, grad_accum) - if args.distributed_training_framework == DistributedBackend.FSDP.value: - model = accelerator.prepare(m.model) - m.update_model(model) + # accelerator does not need optimizer to init, in fact, the optimizer needs to be initialized AFTER the Accelerator + accelerator = Accelerator( + model=m, + samples_per_gpu=args.samples_per_gpu, + grad_accum=grad_accum, + train_loader=train_loader, + distributed_framework=DistributedBackend(args.distributed_training_framework), + fsdp_sharding_strategy=args.fsdp_sharding_strategy, + deepspeed_cpu_offload_optimizer=args.cpu_offload_optimizer, + deepspeed_cpu_offload_optimizer_pin_memory=args.cpu_offload_optimizer_pin_memory, + deepspeed_cpu_offload_optimizer_ratio=args.cpu_offload_optimizer_ratio, + fsdp_cpu_offload_params=args.cpu_offload_params_fsdp, + save_samples=args.save_samples, + ) + # optimizer needs model that has been prepared by accelerator + # and then accelerator needs to be prepared AGAIN once optimizer is initialized optimizer = setup_optimizer( model=m, cpu_offload=args.cpu_offload_optimizer, name=None, # choose based on backend learning_rate=args.learning_rate, ) - - lr_scheduler = get_scheduler( - name=args.lr_scheduler, + accelerator.prepare_with_optimizer( optimizer=optimizer, + lr_scheduler=args.lr_scheduler, + num_epochs=args.num_epochs, num_warmup_steps=args.num_warmup_steps, - 
num_training_steps=args.num_epochs * len(train_loader) // grad_accum, - ) - - # TODO cdoern: `m.model` should be hidden behind a custom `Accelerator class` - # when using the training library `Model` class should be first class - model, optimizer, _, lr_scheduler = accelerator.prepare( - m.model, - optimizer, - deepcopy(train_loader), - lr_scheduler, ) - m.update_model(model) - # Necessary so that Accelerate does not step once per GPU - # see https://github.com/huggingface/accelerate/blob/127818fc27ebe5cb236357fff59ff1748326d643/src/accelerate/scheduler.py#L69 - lr_scheduler.split_batches = True + # TODO: make this work more seamlessly + optimizer = accelerator.optimizer + m = accelerator.model load_latest_full_state(args=args, accelerator=accelerator) train( args, - m, - optimizer, - lr_scheduler, - accelerator, - tokenizer, - train_loader, - grad_accum, + model=m, + optimizer=optimizer, + accelerator=accelerator, ) torch.distributed.barrier() diff --git a/src/instructlab/training/setup_accelerator.py b/src/instructlab/training/setup_accelerator.py deleted file mode 100644 index 55e1566a..00000000 --- a/src/instructlab/training/setup_accelerator.py +++ /dev/null @@ -1,135 +0,0 @@ -# Standard -from functools import partial - -# Third Party -from accelerate import Accelerator -from peft.utils.other import fsdp_auto_wrap_policy -from torch.distributed.fsdp import BackwardPrefetch, ShardingStrategy -from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy -from transformers import PreTrainedModel -import torch - -# First Party -from instructlab.training.config import DeepSpeedOptions -from instructlab.training.utils import get_module_class_from_name, patch_target_module - - -def get_ds_plugin(world_size, samples_per_gpu, grad_accum, opts: DeepSpeedOptions): - # Third Party - from accelerate.utils import DeepSpeedPlugin - - ds_config = { - "train_batch_size": samples_per_gpu * world_size * grad_accum, - "gradient_accumulation_steps": grad_accum, - "train_micro_batch_size_per_gpu": samples_per_gpu, - "steps_per_print": 1, - "zero_optimization": { - "stage": 2, - # this option is only supported with DeepSpeed ZeRO stage 3 - "offload_param": {"device": "none"}, - "offload_optimizer": {"device": "none"}, - }, - "bf16": {"enabled": True}, - "gradient_clipping": 1.0, - "prescale_gradients": False, - "wall_clock_breakdown": False, - } - - if opts.cpu_offload_optimizer: - # this only works when the cpu offload optimizer is enabled - ds_config["zero_optimization"]["offload_optimizer"] = { - # CPU offloading is the only option available in ZeRO stage 2 - "device": "cpu", - "pin_memory": opts.cpu_offload_optimizer_pin_memory, - "ratio": opts.cpu_offload_optimizer_ratio, - } - ds_plugin = DeepSpeedPlugin( - hf_ds_config=ds_config, - ) - return ds_plugin - - -def get_fsdp_config(args, model: PreTrainedModel): - # Third Party - from accelerate.utils import FullyShardedDataParallelPlugin - from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload - - is_lora = args.lora_r > 0 - block_name = model._no_split_modules[0] - - wrap_policy = None - if is_lora > 0: - wrap_policy = fsdp_auto_wrap_policy(model) - else: - wrap_policy = partial( - transformer_auto_wrap_policy, - transformer_layer_cls={ - get_module_class_from_name(model, block_name), - }, - ) - - # TODO(osilkin): BACKWARD_POST trades memory utilization for processing time, which is important for systems utilizing LoRA - # We should have this be configurable in the future. 
- prefetch_policy = ( - BackwardPrefetch.BACKWARD_POST if is_lora else BackwardPrefetch.BACKWARD_PRE - ) - fsdp_plugin = FullyShardedDataParallelPlugin( - auto_wrap_policy=wrap_policy, - limit_all_gathers=True, - backward_prefetch=prefetch_policy, - sharding_strategy=ShardingStrategy[args.fsdp_sharding_strategy], - cpu_offload=CPUOffload(args.cpu_offload_params_fsdp), - ) - - # `use_orig_params` must be disabled when using LoRA and FSDP together - # Source: https://huggingface.co/docs/peft/en/accelerate/fsdp#the-important-parts - if args.lora_r > 0: - fsdp_plugin.use_orig_params = False - - return fsdp_plugin - - -def setup_accelerator(args, model: PreTrainedModel, grad_accum): - if args.distributed_training_framework == "deepspeed": - try: - # Third Party - from deepspeed import DeepSpeedEngine - except ImportError as exc: - raise ImportError( - "DeepSpeed selected as distributed framework, but not installed" - ) from exc - - # patch deepspeed to work with quantized models. - if args.lora_quant_bits is not None: - patch_target_module( - "deepspeed.DeepSpeedEngine", - partial(DeepSpeedEngine, dont_change_device=True), - ) - - accel_args = { - "deepspeed_plugin": get_ds_plugin( - world_size=torch.distributed.get_world_size(), - samples_per_gpu=args.samples_per_gpu, - grad_accum=grad_accum, - opts=DeepSpeedOptions( - cpu_offload_optimizer=args.cpu_offload_optimizer, - cpu_offload_optimizer_ratio=args.cpu_offload_optimizer_ratio, - cpu_offload_optimizer_pin_memory=args.cpu_offload_optimizer_pin_memory, - save_samples=args.save_samples_ds, - ), - ), - } - elif args.distributed_training_framework == "fsdp": - accel_args = { - "fsdp_plugin": get_fsdp_config(args, model), - "mixed_precision": "bf16", - } - else: - raise ValueError( - f"Unknown sharding framework: {args.distributed_training_framework}" - ) - accelerator = Accelerator( - **accel_args, - ) - accelerator.even_batches = False - return accelerator diff --git a/tests/unit/test_accelerator.py b/tests/unit/test_accelerator.py new file mode 100644 index 00000000..c208580d --- /dev/null +++ b/tests/unit/test_accelerator.py @@ -0,0 +1,249 @@ +# Standard +from unittest.mock import MagicMock, patch +import os + +# Third Party +from torch.utils.data import DataLoader +import pytest +import torch + +# First Party +from instructlab.training.accelerator import Accelerator +from instructlab.training.config import DeepSpeedOptions, DistributedBackend +from instructlab.training.model import Model + + +@pytest.fixture +def mock_model(): + model = MagicMock(spec=Model) + model.model = MagicMock() + model.lora_config = None + model._no_split_modules = ["LlamaDecoderLayer"] + # Add children method to model + model.children = MagicMock(return_value=[]) + model.model.children = MagicMock(return_value=[]) + # Add get_module_class_from_name method + model.get_module_class_from_name = MagicMock(return_value=torch.nn.Module) + return model + + +@pytest.fixture +def mock_train_loader(): + loader = MagicMock(spec=DataLoader) + loader.dataset = MagicMock() + return loader + + +@pytest.fixture +def mock_optimizer(): + optimizer = MagicMock(spec=torch.optim.Optimizer) + # Add param_groups attribute with required keys + optimizer.param_groups = [{"params": [], "lr": 1e-4}] + return optimizer + + +@pytest.fixture +def mock_transformers_accel(): + with patch("instructlab.training.accelerator.TransformersAccel") as mock: + yield mock + + +def test_accelerator_init_deepspeed( + mock_model, mock_train_loader, mock_transformers_accel +): + with 
patch("torch.distributed.get_world_size", return_value=2): + accelerator = Accelerator( + model=mock_model, + samples_per_gpu=8, + grad_accum=2, + train_loader=mock_train_loader, + save_samples=1000, + distributed_framework=DistributedBackend.DEEPSPEED, + deepspeed_cpu_offload_optimizer_ratio=1.0, # Add default value + ) + + assert accelerator.samples_per_gpu == 8 + assert accelerator.grad_accum == 2 + assert accelerator.model == mock_model + assert accelerator.distributed_framework == DistributedBackend.DEEPSPEED + assert accelerator.train_loader == mock_train_loader + assert accelerator.save_samples == 1000 + + +def test_accelerator_init_fsdp(mock_model, mock_train_loader, mock_transformers_accel): + with patch("torch.distributed.get_world_size", return_value=2): + accelerator = Accelerator( + model=mock_model, + samples_per_gpu=8, + grad_accum=2, + train_loader=mock_train_loader, + save_samples=1000, + distributed_framework=DistributedBackend.FSDP, + fsdp_sharding_strategy="HYBRID_SHARD", + ) + + assert accelerator.samples_per_gpu == 8 + assert accelerator.grad_accum == 2 + assert accelerator.model == mock_model + assert accelerator.distributed_framework == DistributedBackend.FSDP + assert accelerator.fsdp_sharding_strategy == "HYBRID_SHARD" + + +def test_accelerator_prepare_with_optimizer( + mock_model, mock_train_loader, mock_optimizer, mock_transformers_accel +): + with patch("torch.distributed.get_world_size", return_value=2): + accelerator = Accelerator( + model=mock_model, + samples_per_gpu=8, + grad_accum=2, + train_loader=mock_train_loader, + save_samples=1000, + distributed_framework=DistributedBackend.DEEPSPEED, + deepspeed_cpu_offload_optimizer_ratio=1.0, # Add default value + ) + + # Mock the accelerator's prepare method + accelerator.accelerator = MagicMock() + accelerator.accelerator.prepare.return_value = ( + mock_model.model, + mock_optimizer, + mock_train_loader, + MagicMock(), # lr_scheduler + ) + + accelerator.prepare_with_optimizer( + optimizer=mock_optimizer, + lr_scheduler="cosine", + num_epochs=3, + num_warmup_steps=100, + ) + + # Verify that prepare was called with the correct arguments + accelerator.accelerator.prepare.assert_called_once() + assert accelerator.optimizer == mock_optimizer + + +def test_accelerator_deepspeed_cpu_offload( + mock_model, mock_train_loader, mock_transformers_accel +): + with patch("torch.distributed.get_world_size", return_value=2): + accelerator = Accelerator( + model=mock_model, + samples_per_gpu=8, + grad_accum=2, + train_loader=mock_train_loader, + save_samples=1000, + distributed_framework=DistributedBackend.DEEPSPEED, + deepspeed_cpu_offload_optimizer=True, + deepspeed_cpu_offload_optimizer_pin_memory=True, + deepspeed_cpu_offload_optimizer_ratio=0.5, + ) + + assert accelerator.deepspeed_cpu_offload_optimizer is True + assert accelerator.deepspeed_cpu_offload_optimizer_pin_memory is True + assert accelerator.deepspeed_cpu_offload_optimizer_ratio == 0.5 + + +def test_accelerator_fsdp_cpu_offload( + mock_model, mock_train_loader, mock_transformers_accel +): + with patch("torch.distributed.get_world_size", return_value=2): + accelerator = Accelerator( + model=mock_model, + samples_per_gpu=8, + grad_accum=2, + train_loader=mock_train_loader, + save_samples=1000, + distributed_framework=DistributedBackend.FSDP, + fsdp_sharding_strategy="HYBRID_SHARD", + fsdp_cpu_offload_params=True, + ) + + assert accelerator.fsdp_cpu_offload_params is True + + +def test_accelerator_getattr(mock_model, mock_train_loader, mock_transformers_accel): + 
with patch("torch.distributed.get_world_size", return_value=2): + accelerator = Accelerator( + model=mock_model, + samples_per_gpu=8, + grad_accum=2, + train_loader=mock_train_loader, + save_samples=1000, + distributed_framework=DistributedBackend.DEEPSPEED, + deepspeed_cpu_offload_optimizer_ratio=1.0, # Add default value + ) + + # Mock a method on the underlying accelerator + mock_method = MagicMock() + accelerator.accelerator = MagicMock() + accelerator.accelerator.some_method = mock_method + + # Test that __getattr__ forwards to the underlying accelerator + result = accelerator.some_method() + assert result == mock_method.return_value + + +def test_accelerator_setup_deepspeed_classmethod( + mock_model, mock_train_loader, mock_transformers_accel +): + with patch("torch.distributed.get_world_size", return_value=2): + accelerator = Accelerator.setup_deepspeed( + model=mock_model, + samples_per_gpu=8, + grad_accum=2, + train_loader=mock_train_loader, + deepspeed_cpu_offload_optimizer=True, + deepspeed_cpu_offload_optimizer_pin_memory=True, + deepspeed_cpu_offload_optimizer_ratio=0.5, + save_samples=1000, + ) + + assert isinstance(accelerator, Accelerator) + assert accelerator.distributed_framework == DistributedBackend.DEEPSPEED + assert accelerator.deepspeed_cpu_offload_optimizer is True + + +def test_accelerator_setup_fsdp_classmethod( + mock_model, mock_train_loader, mock_transformers_accel +): + with patch("torch.distributed.get_world_size", return_value=2): + accelerator = Accelerator.setup_fsdp( + model=mock_model, + samples_per_gpu=8, + grad_accum=2, + train_loader=mock_train_loader, + fsdp_sharding_strategy="HYBRID_SHARD", + fsdp_cpu_offload_params=True, + save_samples=1000, + ) + + assert isinstance(accelerator, Accelerator) + assert accelerator.distributed_framework == DistributedBackend.FSDP + assert accelerator.fsdp_sharding_strategy == "HYBRID_SHARD" + assert accelerator.fsdp_cpu_offload_params is True + + +def test_accelerator_with_lora(mock_model, mock_train_loader, mock_transformers_accel): + # Set up a mock LoRA config + mock_model.lora_config = MagicMock() + mock_model.lora_config.target_modules = ["q_proj", "v_proj"] + + # Mock the fsdp_auto_wrap_policy function + mock_wrap_policy = MagicMock() + with patch("peft.utils.other.fsdp_auto_wrap_policy", return_value=mock_wrap_policy): + with patch("torch.distributed.get_world_size", return_value=2): + accelerator = Accelerator( + model=mock_model, + samples_per_gpu=8, + grad_accum=2, + train_loader=mock_train_loader, + save_samples=1000, + distributed_framework=DistributedBackend.FSDP, + fsdp_sharding_strategy="HYBRID_SHARD", + ) + + # Verify that the accelerator was initialized with LoRA config + assert accelerator.model.lora_config is not None + assert accelerator.model.lora_config.target_modules == ["q_proj", "v_proj"]