diff --git a/docs/source/en/main_classes/deepspeed.md b/docs/source/en/main_classes/deepspeed.md index ae0292e8ee30..201d12895472 100644 --- a/docs/source/en/main_classes/deepspeed.md +++ b/docs/source/en/main_classes/deepspeed.md @@ -2065,20 +2065,20 @@ In this case you usually need to raise the value of `initial_scale_power`. Setti ## Non-Trainer Deepspeed Integration -The [`~deepspeed.HfDeepSpeedConfig`] is used to integrate Deepspeed into the 🤗 Transformers core +The [`~integrations.HfDeepSpeedConfig`] is used to integrate Deepspeed into the 🤗 Transformers core functionality, when [`Trainer`] is not used. The only thing that it does is handling Deepspeed ZeRO-3 param gathering and automatically splitting the model onto multiple gpus during `from_pretrained` call. Everything else you have to do by yourself. When using [`Trainer`] everything is automatically taken care of. When not using [`Trainer`], to efficiently deploy DeepSpeed ZeRO-3, you must instantiate the -[`~deepspeed.HfDeepSpeedConfig`] object before instantiating the model and keep that object alive. +[`~integrations.HfDeepSpeedConfig`] object before instantiating the model and keep that object alive. If you're using Deepspeed ZeRO-1 or ZeRO-2 you don't need to use `HfDeepSpeedConfig` at all. For example for a pretrained model: ```python -from transformers.deepspeed import HfDeepSpeedConfig +from transformers.integrations import HfDeepSpeedConfig from transformers import AutoModel import deepspeed @@ -2092,7 +2092,7 @@ engine = deepspeed.initialize(model=model, config_params=ds_config, ...) or for non-pretrained model: ```python -from transformers.deepspeed import HfDeepSpeedConfig +from transformers.integrations import HfDeepSpeedConfig from transformers import AutoModel, AutoConfig import deepspeed @@ -2108,7 +2108,7 @@ Please note that if you're not using the [`Trainer`] integration, you're complet ## HfDeepSpeedConfig -[[autodoc]] deepspeed.HfDeepSpeedConfig +[[autodoc]] integrations.HfDeepSpeedConfig - all ### Custom DeepSpeed ZeRO Inference @@ -2161,7 +2161,7 @@ Make sure to: from transformers import AutoTokenizer, AutoConfig, AutoModelForSeq2SeqLM -from transformers.deepspeed import HfDeepSpeedConfig +from transformers.integrations import HfDeepSpeedConfig import deepspeed import os import torch diff --git a/examples/research_projects/wav2vec2/test_wav2vec2_deepspeed.py b/examples/research_projects/wav2vec2/test_wav2vec2_deepspeed.py index 0f3e239df6d2..d44145f3e0c1 100644 --- a/examples/research_projects/wav2vec2/test_wav2vec2_deepspeed.py +++ b/examples/research_projects/wav2vec2/test_wav2vec2_deepspeed.py @@ -32,7 +32,7 @@ from parameterized import parameterized # noqa from transformers import TrainingArguments, is_torch_available # noqa -from transformers.deepspeed import is_deepspeed_available # noqa +from transformers.integrations.deepspeed import is_deepspeed_available # noqa from transformers.file_utils import WEIGHTS_NAME # noqa from transformers.testing_utils import ( # noqa CaptureLogger, diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index 9b95aadffccc..45636b9ffe72 100644 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -94,6 +94,7 @@ "data.metrics": [], "data.processors": [], "debug_utils": [], + "deepspeed": [], "dependency_versions_check": [], "dependency_versions_table": [], "dynamic_module_utils": [], @@ -115,8 +116,6 @@ "is_tensorboard_available", "is_wandb_available", ], - "lib_integrations": [], - "lib_integrations.peft": [], "modelcard": 
["ModelCard"], "modeling_tf_pytorch_utils": [ "convert_tf_weight_name_to_pt_weight_name", @@ -745,7 +744,6 @@ "is_vision_available", "logging", ], - "utils.bitsandbytes": [], "utils.quantization_config": ["BitsAndBytesConfig", "GPTQConfig"], } @@ -1002,7 +1000,6 @@ "TextDataset", "TextDatasetForNextSentencePrediction", ] - _import_structure["deepspeed"] = [] _import_structure["generation"].extend( [ "BeamScorer", diff --git a/src/transformers/deepspeed.py b/src/transformers/deepspeed.py index 7af2bedece84..840d9cc2f55a 100644 --- a/src/transformers/deepspeed.py +++ b/src/transformers/deepspeed.py @@ -12,378 +12,29 @@ # See the License for the specific language governing permissions and # limitations under the License. """ -Integration with Deepspeed -""" - -import importlib.util -import weakref -from functools import partialmethod - -from .dependency_versions_check import dep_version_check -from .utils import is_accelerate_available, is_torch_available, logging - - -if is_torch_available(): - import torch - -logger = logging.get_logger(__name__) - - -def is_deepspeed_available(): - return importlib.util.find_spec("deepspeed") is not None - - -if is_accelerate_available() and is_deepspeed_available(): - from accelerate.utils.deepspeed import HfDeepSpeedConfig as DeepSpeedConfig -else: - # Inherits from a dummy `object` if accelerate is not available, so that python succeeds to import this file. - # Deepspeed glue code will never inherit this dummy object as it checks if accelerate is available. - from builtins import object as DeepSpeedConfig - - -class HfDeepSpeedConfig(DeepSpeedConfig): - """ - This object contains a DeepSpeed configuration dictionary and can be quickly queried for things like zero stage. - - A `weakref` of this object is stored in the module's globals to be able to access the config from areas where - things like the Trainer object is not available (e.g. `from_pretrained` and `_get_resized_embeddings`). Therefore - it's important that this object remains alive while the program is still running. - - [`Trainer`] uses the `HfTrainerDeepSpeedConfig` subclass instead. That subclass has logic to sync the configuration - with values of [`TrainingArguments`] by replacing special placeholder values: `"auto"`. Without this special logic - the DeepSpeed configuration is not modified in any way. - - Args: - config_file_or_dict (`Union[str, Dict]`): path to DeepSpeed config file or dict. - - """ - - def __init__(self, config_file_or_dict): - # set global weakref object - set_hf_deepspeed_config(self) - dep_version_check("accelerate") - dep_version_check("deepspeed") - super().__init__(config_file_or_dict) - - -class HfTrainerDeepSpeedConfig(HfDeepSpeedConfig): - """ - The `HfTrainerDeepSpeedConfig` object is meant to be created during `TrainingArguments` object creation and has the - same lifespan as the latter. - """ - - def __init__(self, config_file_or_dict): - super().__init__(config_file_or_dict) - self._dtype = None - self.mismatches = [] - - def dtype(self): - if self._dtype is None: - raise ValueError("trainer_config_process() wasn't called yet to tell dtype") - return self._dtype - - def is_auto(self, ds_key_long): - val = self.get_value(ds_key_long) - if val is None: - return False - else: - return val == "auto" - - def fill_match(self, ds_key_long, hf_val, hf_key=None, must_match=True): - """ - A utility method that massages the config file and can optionally verify that the values match. - - 1. Replace "auto" values with `TrainingArguments` value. - - 2. 
If it wasn't "auto" and `must_match` is true, then check that DS config matches Trainer - config values and if mismatched add the entry to `self.mismatched` - will assert during - `trainer_config_finalize` for one or more mismatches. - - """ - config, ds_key = self.find_config_node(ds_key_long) - if config is None: - return - - if config.get(ds_key) == "auto": - config[ds_key] = hf_val - return - - if not must_match: - return - - ds_val = config.get(ds_key) - if ds_val is not None and ds_val != hf_val: - self.mismatches.append(f"- ds {ds_key_long}={ds_val} vs hf {hf_key}={hf_val}") - - fill_only = partialmethod(fill_match, must_match=False) - - def trainer_config_process(self, args): - """ - Adjust the config with `TrainingArguments` values. This stage is run during `TrainingArguments` object - creation. - """ - # DeepSpeed does: - # train_batch_size = world_size * train_micro_batch_size_per_gpu * gradient_accumulation_steps - train_batch_size = args.world_size * args.per_device_train_batch_size * args.gradient_accumulation_steps - self.fill_match( - "train_micro_batch_size_per_gpu", args.per_device_train_batch_size, "per_device_train_batch_size" - ) - self.fill_match("gradient_accumulation_steps", args.gradient_accumulation_steps, "gradient_accumulation_steps") - self.fill_match("train_batch_size", train_batch_size, "train_batch_size (calculated)") - self.fill_match("gradient_clipping", args.max_grad_norm, "max_grad_norm") - - self.fill_match("optimizer.params.lr", args.learning_rate, "learning_rate") - self.fill_match("optimizer.params.betas", [args.adam_beta1, args.adam_beta2], "adam_beta1+adam_beta2") - self.fill_match("optimizer.params.eps", args.adam_epsilon, "adam_epsilon") - self.fill_match("optimizer.params.weight_decay", args.weight_decay, "weight_decay") - - self.fill_only("scheduler.params.warmup_min_lr", 0) # not a trainer arg - self.fill_match("scheduler.params.warmup_max_lr", args.learning_rate, "learning_rate") - # total_num_steps - will get set in trainer_config_finalize - - # fp16 - if args.fp16 or args.fp16_full_eval: - fp16_backend = "apex" if args.fp16_backend == "apex" else "amp" - else: - fp16_backend = None - - if args.save_on_each_node: - # deepspeed uses shared storage by default. Let's override this setting if save_on_each_node == True - self.config["checkpoint"] = self.config.get("checkpoint", {}) - self.config["checkpoint"]["use_node_local_storage"] = args.save_on_each_node - - # amp: similar to the pytorch native amp - it has a bunch of optional params but we won't set - # any here unless the user did the work - self.fill_match( - "fp16.enabled", - ((args.fp16 or args.fp16_full_eval) and fp16_backend == "amp"), - "fp16|fp16_full_eval+fp16_backend(amp)", - ) - - # apex: delegates amp work to apex (which needs to be available), but it cannot be used with any - # ZeRO features - self.fill_match("amp.enabled", fp16_backend == "apex", "fp16+fp16_backend(apex)") - self.fill_match("amp.opt_level", args.fp16_opt_level, "fp16_opt_level") - - self.fill_match("bf16.enabled", (args.bf16 or args.bf16_full_eval), "bf16|bf16_full_eval") - - # deepspeed's default mode is fp16 unless there is a config that says differently - if self.is_true("bf16.enabled"): - self._dtype = torch.bfloat16 - elif self.is_false("fp16.enabled"): - self._dtype = torch.float32 - else: - self._dtype = torch.float16 - - def trainer_config_finalize(self, args, model, num_training_steps): - """ - This stage is run after we have the model and know num_training_steps. 
- - Now we can complete the configuration process. - """ - # zero - - # deal with config keys that use `auto` value and rely on model's hidden_size - hidden_size_based_keys = [ - "zero_optimization.reduce_bucket_size", - "zero_optimization.stage3_prefetch_bucket_size", - "zero_optimization.stage3_param_persistence_threshold", - ] - hidden_size_auto_keys = [x for x in hidden_size_based_keys if self.is_auto(x)] +Integration with Deepspeed - kept for backward compatiblity, if you plan to make any edit, make sure to modify the file +in `integrations/deepspeed` instead. - if len(hidden_size_auto_keys) > 0: - if hasattr(model.config, "hidden_size"): - hidden_size = model.config.hidden_size - elif hasattr(model.config, "hidden_sizes"): - # if there are many hidden sizes pick the largest one - hidden_size = max(model.config.hidden_sizes) - else: - raise ValueError( - "The model's config file has neither `hidden_size` nor `hidden_sizes` entry, " - "therefore it's not possible to automatically fill out the following `auto` entries " - f"in the DeepSpeed config file: {hidden_size_auto_keys}. You can fix that by replacing " - "`auto` values for these keys with an integer value of your choice." - ) - - self.fill_only("zero_optimization.reduce_bucket_size", hidden_size * hidden_size) - if self.is_zero3(): - # automatically assign the optimal config values based on model config - self.fill_only("zero_optimization.stage3_prefetch_bucket_size", 0.9 * hidden_size * hidden_size) - self.fill_only("zero_optimization.stage3_param_persistence_threshold", 10 * hidden_size) - - # scheduler - self.fill_match("scheduler.params.total_num_steps", num_training_steps, "num_training_steps (calculated)") - self.fill_match("scheduler.params.warmup_num_steps", args.get_warmup_steps(num_training_steps), "warmup_steps") - - if len(self.mismatches) > 0: - mismatches = "\n".join(self.mismatches) - raise ValueError( - "Please correct the following DeepSpeed config values that mismatch TrainingArguments" - f" values:\n{mismatches}\nThe easiest method is to set these DeepSpeed config values to 'auto'." - ) - - -# keep the config object global to be able to access it anywhere during TrainingArguments life-cycle -_hf_deepspeed_config_weak_ref = None - - -def set_hf_deepspeed_config(hf_deepspeed_config_obj): - # this is a special weakref global object to allow us to get to Deepspeed config from APIs - # that don't have an easy way to get to the Deepspeed config outside of the Trainer domain. - global _hf_deepspeed_config_weak_ref - # will go away automatically when HfDeepSpeedConfig is destroyed (when TrainingArguments is destroyed) - _hf_deepspeed_config_weak_ref = weakref.ref(hf_deepspeed_config_obj) - - -def unset_hf_deepspeed_config(): - # useful for unit tests to ensure the global state doesn't leak - call from `tearDown` method - global _hf_deepspeed_config_weak_ref - _hf_deepspeed_config_weak_ref = None - - -def is_deepspeed_zero3_enabled(): - if _hf_deepspeed_config_weak_ref is not None and _hf_deepspeed_config_weak_ref() is not None: - return _hf_deepspeed_config_weak_ref().is_zero3() - else: - return False - - -def deepspeed_config(): - if _hf_deepspeed_config_weak_ref is not None and _hf_deepspeed_config_weak_ref() is not None: - return _hf_deepspeed_config_weak_ref().config - else: - return None - - -def deepspeed_optim_sched(trainer, hf_deepspeed_config, args, num_training_steps, model_parameters): - """ - A convenience wrapper that deals with optimizer and lr scheduler configuration. 
- """ - from accelerate.utils import DummyOptim, DummyScheduler - - config = hf_deepspeed_config.config - - # Optimizer + Scheduler - # Currently supported combos: - # 1. DS scheduler + DS optimizer: Yes - # 2. HF scheduler + HF optimizer: Yes - # 3. DS scheduler + HF optimizer: Yes - # 4. HF scheduler + DS optimizer: No - # - # Unless Offload is enabled in which case it's: - # 1. DS scheduler + DS optimizer: Yes - # 2. HF scheduler + HF optimizer: Mostly* - # 3. DS scheduler + HF optimizer: Mostly* - # 4. HF scheduler + DS optimizer: No - # - # Mostly*: All non-native DeepSpeed optimizers that have both CPU and GPU implementation should work (except LAMB) - - optimizer = None - if "optimizer" in config: - if args.adafactor: - raise ValueError( - "--adafactor was passed, but also found `optimizer` configured in the DeepSpeed config. " - "Only one optimizer can be configured." - ) - optimizer = DummyOptim(params=model_parameters) - else: - if hf_deepspeed_config.is_offload(): - logger.info( - "Detected ZeRO Offload and non-DeepSpeed optimizers: This combination should work as long as the" - " custom optimizer has both CPU and GPU implementation (except LAMB)" - ) - - # ds supports Adam, OneBitAdam, and Lamb optimizers and can import other optimizers from torch. - # But trainer uses AdamW by default. - optimizer = trainer.create_optimizer() - # To use other optimizers requires voiding warranty with: `zero_allow_untested_optimizer` - config["zero_allow_untested_optimizer"] = True - - lr_scheduler = None - if "scheduler" in config: - lr_scheduler = DummyScheduler(optimizer) - else: - if isinstance(optimizer, DummyOptim): - raise ValueError( - "Found `optimizer` configured in the DeepSpeed config, but no `scheduler`. " - "Please configure a scheduler in the DeepSpeed config." - ) - lr_scheduler = trainer.create_scheduler(num_training_steps=num_training_steps, optimizer=optimizer) - - return optimizer, lr_scheduler - - -def deepspeed_init(trainer, num_training_steps, inference=False): - """ - Init DeepSpeed, after updating the DeepSpeed configuration with any relevant Trainer's args. - - If `resume_from_checkpoint` was passed then an attempt to resume from a previously saved checkpoint will be made. 
- - Args: - trainer: Trainer object - num_training_steps: per single gpu - resume_from_checkpoint: path to a checkpoint if to resume from after normal DeepSpeedEngine load - inference: launch in inference mode (no optimizer and no lr scheduler) - - Returns: optimizer, lr_scheduler - - We may use `deepspeed_init` more than once during the life of Trainer, when we do - it's a temp hack based on: - https://github.com/microsoft/DeepSpeed/issues/1394#issuecomment-937405374 until Deepspeed fixes a bug where it - can't resume from a checkpoint after it did some stepping https://github.com/microsoft/DeepSpeed/issues/1612 - - """ - from deepspeed.utils import logger as ds_logger - - model = trainer.model - args = trainer.args - - hf_deepspeed_config = trainer.accelerator.state.deepspeed_plugin.hf_ds_config - - # resume config update - some bits like `model` and `num_training_steps` only become available during train - hf_deepspeed_config.trainer_config_finalize(args, model, num_training_steps) - - # set the Deepspeed log level consistent with the Trainer - ds_logger.setLevel(args.get_process_log_level()) - - if inference: - # only Z3 makes sense for the inference - if not hf_deepspeed_config.is_zero3(): - raise ValueError("ZeRO inference only makes sense with ZeRO Stage 3 - please adjust your config") - - # in case the training config is re-used for inference - hf_deepspeed_config.del_config_sub_tree("optimizer") - hf_deepspeed_config.del_config_sub_tree("lr_scheduler") - optimizer, lr_scheduler = None, None - model_parameters = None - else: - trainer.optimizer = None # important for when deepspeed_init is used as re-init - model_parameters = list(filter(lambda p: p.requires_grad, model.parameters())) - optimizer, lr_scheduler = deepspeed_optim_sched( - trainer, hf_deepspeed_config, args, num_training_steps, model_parameters - ) - - # keep for quick debug: - # from pprint import pprint; pprint(config) - - return optimizer, lr_scheduler - - -def deepspeed_load_checkpoint(deepspeed_engine, checkpoint_path): - # it's possible that the user is trying to resume from model_path, which doesn't necessarily - # contain a deepspeed checkpoint. e.g. examples just check if the dir exists and assume it's - # a resume from a checkpoint and not just a local pretrained weight. So we check here if the - # path contains what looks like a deepspeed checkpoint - import glob - - deepspeed_checkpoint_dirs = sorted(glob.glob(f"{checkpoint_path}/global_step*")) - - if len(deepspeed_checkpoint_dirs) > 0: - logger.info(f"Attempting to resume from {checkpoint_path}") - # this magically updates self.optimizer and self.lr_scheduler - load_path, _ = deepspeed_engine.load_checkpoint( - checkpoint_path, load_optimizer_states=True, load_lr_scheduler_states=True - ) - if load_path is None: - raise ValueError(f"[deepspeed] failed to resume from checkpoint {checkpoint_path}") - else: - raise ValueError(f"Can't find a valid checkpoint at {checkpoint_path}") +Check: https://github.com/huggingface/transformers/pull/25599 +""" +import warnings + + +warnings.warn( + "transformers.deepspeed module is deprecated and will be removed in a future version. 
Please import deepspeed modules directly from transformers.integrations", + FutureWarning, +) + +# Backward compatibility imports, to make sure all those objects can be found in integrations/deepspeed +from .integrations.deepspeed import ( # noqa + HfDeepSpeedConfig, + HfTrainerDeepSpeedConfig, + deepspeed_config, + deepspeed_init, + deepspeed_load_checkpoint, + deepspeed_optim_sched, + is_deepspeed_available, + is_deepspeed_zero3_enabled, + set_hf_deepspeed_config, + unset_hf_deepspeed_config, +) diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 404943da05ac..5d242fd73ad2 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -24,7 +24,7 @@ import torch.distributed as dist from torch import nn -from ..deepspeed import is_deepspeed_zero3_enabled +from ..integrations.deepspeed import is_deepspeed_zero3_enabled from ..modeling_outputs import CausalLMOutputWithPast, Seq2SeqLMOutput from ..models.auto import ( MODEL_FOR_CAUSAL_IMAGE_MODELING_MAPPING, diff --git a/src/transformers/integrations/__init__.py b/src/transformers/integrations/__init__.py new file mode 100644 index 000000000000..07ef9d6e9012 --- /dev/null +++ b/src/transformers/integrations/__init__.py @@ -0,0 +1,71 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from .bitsandbytes import (
+    get_keys_to_not_convert,
+    replace_8bit_linear,
+    replace_with_bnb_linear,
+    set_module_8bit_tensor_to_device,
+    set_module_quantized_tensor_to_device,
+)
+from .deepspeed import (
+    HfDeepSpeedConfig,
+    HfTrainerDeepSpeedConfig,
+    deepspeed_config,
+    deepspeed_init,
+    deepspeed_load_checkpoint,
+    deepspeed_optim_sched,
+    is_deepspeed_available,
+    is_deepspeed_zero3_enabled,
+    set_hf_deepspeed_config,
+    unset_hf_deepspeed_config,
+)
+from .integration_utils import (
+    INTEGRATION_TO_CALLBACK,
+    AzureMLCallback,
+    ClearMLCallback,
+    CodeCarbonCallback,
+    CometCallback,
+    DagsHubCallback,
+    FlyteCallback,
+    MLflowCallback,
+    NeptuneCallback,
+    NeptuneMissingConfiguration,
+    TensorBoardCallback,
+    WandbCallback,
+    get_available_reporting_integrations,
+    get_reporting_integration_callbacks,
+    hp_params,
+    is_azureml_available,
+    is_clearml_available,
+    is_codecarbon_available,
+    is_comet_available,
+    is_dagshub_available,
+    is_fairscale_available,
+    is_flyte_deck_standard_available,
+    is_flytekit_available,
+    is_mlflow_available,
+    is_neptune_available,
+    is_optuna_available,
+    is_ray_available,
+    is_ray_tune_available,
+    is_sigopt_available,
+    is_tensorboard_available,
+    is_wandb_available,
+    rewrite_logs,
+    run_hp_search_optuna,
+    run_hp_search_ray,
+    run_hp_search_sigopt,
+    run_hp_search_wandb,
+)
+from .peft import PeftAdapterMixin
diff --git a/src/transformers/integrations/bitsandbytes.py b/src/transformers/integrations/bitsandbytes.py
new file mode 100644
index 000000000000..1a8220b1ed7b
--- /dev/null
+++ b/src/transformers/integrations/bitsandbytes.py
@@ -0,0 +1,290 @@
+import importlib.metadata
+import warnings
+from copy import deepcopy
+
+from packaging import version
+
+from ..utils import is_accelerate_available, is_bitsandbytes_available, logging
+
+
+if is_bitsandbytes_available():
+    import bitsandbytes as bnb
+    import torch
+    import torch.nn as nn
+
+    from ..pytorch_utils import Conv1D
+
+if is_accelerate_available():
+    from accelerate import init_empty_weights
+    from accelerate.utils import find_tied_parameters
+
+logger = logging.get_logger(__name__)
+
+
+def set_module_quantized_tensor_to_device(module, tensor_name, device, value=None, fp16_statistics=None):
+    """
+    A helper function to set a given tensor (parameter or buffer) of a module on a specific device (note that doing
+    `param.to(device)` creates a new tensor not linked to the parameter, which is why we need this function). The
+    function is adapted from accelerate's `set_module_tensor_to_device` function, extended to support the class
+    `Int8Params` from `bitsandbytes`.
+
+    Args:
+        module (`torch.nn.Module`):
+            The module in which the tensor we want to move lives.
+        tensor_name (`str`):
+            The full name of the parameter/buffer.
+        device (`int`, `str` or `torch.device`):
+            The device on which to set the tensor.
+        value (`torch.Tensor`, *optional*):
+            The value of the tensor (useful when going from the meta device to any other device).
+        fp16_statistics (`torch.HalfTensor`, *optional*):
+            The list of fp16 statistics to set on the module, used for serialization.
+    """
+    # Recurse if needed
+    if "."
in tensor_name: + splits = tensor_name.split(".") + for split in splits[:-1]: + new_module = getattr(module, split) + if new_module is None: + raise ValueError(f"{module} has no attribute {split}.") + module = new_module + tensor_name = splits[-1] + + if tensor_name not in module._parameters and tensor_name not in module._buffers: + raise ValueError(f"{module} does not have a parameter or a buffer named {tensor_name}.") + is_buffer = tensor_name in module._buffers + old_value = getattr(module, tensor_name) + + if old_value.device == torch.device("meta") and device not in ["meta", torch.device("meta")] and value is None: + raise ValueError(f"{tensor_name} is on the meta device, we need a `value` to put in on {device}.") + + is_4bit = False + is_8bit = False + if is_buffer or not is_bitsandbytes_available(): + is_8bit = False + is_4bit = False + else: + is_4bit = hasattr(bnb.nn, "Params4bit") and isinstance(module._parameters[tensor_name], bnb.nn.Params4bit) + is_8bit = isinstance(module._parameters[tensor_name], bnb.nn.Int8Params) + + if is_8bit or is_4bit: + param = module._parameters[tensor_name] + if param.device.type != "cuda": + if value is None: + new_value = old_value.to(device) + elif isinstance(value, torch.Tensor): + new_value = value.to("cpu") + if value.dtype == torch.int8: + is_8bit_serializable = version.parse(importlib.metadata.version("bitsandbytes")) > version.parse( + "0.37.2" + ) + if not is_8bit_serializable: + raise ValueError( + "Detected int8 weights but the version of bitsandbytes is not compatible with int8 serialization. " + "Make sure to download the latest `bitsandbytes` version. `pip install --upgrade bitsandbytes`." + ) + else: + new_value = torch.tensor(value, device="cpu") + + # Support models using `Conv1D` in place of `nn.Linear` (e.g. gpt2) by transposing the weight matrix prior to quantization. + # Since weights are saved in the correct "orientation", we skip transposing when loading. + if issubclass(module.source_cls, Conv1D) and fp16_statistics is None: + new_value = new_value.T + + kwargs = old_value.__dict__ + if is_8bit: + new_value = bnb.nn.Int8Params(new_value, requires_grad=False, **kwargs).to(device) + elif is_4bit: + new_value = bnb.nn.Params4bit(new_value, requires_grad=False, **kwargs).to(device) + + module._parameters[tensor_name] = new_value + if fp16_statistics is not None: + setattr(module.weight, "SCB", fp16_statistics.to(device)) + + else: + if value is None: + new_value = old_value.to(device) + elif isinstance(value, torch.Tensor): + new_value = value.to(device) + else: + new_value = torch.tensor(value, device=device) + + if is_buffer: + module._buffers[tensor_name] = new_value + else: + new_value = nn.Parameter(new_value, requires_grad=old_value.requires_grad) + module._parameters[tensor_name] = new_value + + +def _replace_with_bnb_linear( + model, modules_to_not_convert=None, current_key_name=None, quantization_config=None, has_been_replaced=False +): + """ + Private method that wraps the recursion for module replacement. + + Returns the converted model and a boolean that indicates if the conversion has been successfull or not. 
+ """ + for name, module in model.named_children(): + if current_key_name is None: + current_key_name = [] + current_key_name.append(name) + + if (isinstance(module, nn.Linear) or isinstance(module, Conv1D)) and name not in modules_to_not_convert: + # Check if the current key is not in the `modules_to_not_convert` + if not any(key in ".".join(current_key_name) for key in modules_to_not_convert): + with init_empty_weights(): + if isinstance(module, Conv1D): + in_features, out_features = module.weight.shape + else: + in_features = module.in_features + out_features = module.out_features + + if quantization_config.quantization_method() == "llm_int8": + model._modules[name] = bnb.nn.Linear8bitLt( + in_features, + out_features, + module.bias is not None, + has_fp16_weights=quantization_config.llm_int8_has_fp16_weight, + threshold=quantization_config.llm_int8_threshold, + ) + has_been_replaced = True + else: + if ( + quantization_config.llm_int8_skip_modules is not None + and name in quantization_config.llm_int8_skip_modules + ): + pass + else: + model._modules[name] = bnb.nn.Linear4bit( + in_features, + out_features, + module.bias is not None, + quantization_config.bnb_4bit_compute_dtype, + compress_statistics=quantization_config.bnb_4bit_use_double_quant, + quant_type=quantization_config.bnb_4bit_quant_type, + ) + has_been_replaced = True + # Store the module class in case we need to transpose the weight later + model._modules[name].source_cls = type(module) + # Force requires grad to False to avoid unexpected errors + model._modules[name].requires_grad_(False) + if len(list(module.children())) > 0: + _, has_been_replaced = _replace_with_bnb_linear( + module, + modules_to_not_convert, + current_key_name, + quantization_config, + has_been_replaced=has_been_replaced, + ) + # Remove the last key for recursion + current_key_name.pop(-1) + return model, has_been_replaced + + +def replace_with_bnb_linear(model, modules_to_not_convert=None, current_key_name=None, quantization_config=None): + """ + A helper function to replace all `torch.nn.Linear` modules by `bnb.nn.Linear8bit` modules from the `bitsandbytes` + library. This will enable running your models using mixed int8 precision as described by the paper `LLM.int8(): + 8-bit Matrix Multiplication for Transformers at Scale`. Make sure `bitsandbytes` compiled with the correct CUDA + version of your hardware is installed before running this function. `pip install -i https://test.pypi.org/simple/ + bitsandbytes` + + The function will be run recursively and replace all `torch.nn.Linear` modules except for the `lm_head` that should + be kept as a `torch.nn.Linear` module. The replacement is done under `init_empty_weights` context manager so no + CPU/GPU memory is required to run this function. Int8 mixed-precision matrix decomposition works by separating a + matrix multiplication into two streams: (1) and systematic feature outlier stream matrix multiplied in fp16 + (0.01%), (2) a regular stream of int8 matrix multiplication (99.9%). With this method, int8 inference with no + predictive degradation is possible for very large models (>=176B parameters). + + Parameters: + model (`torch.nn.Module`): + Input model or `torch.nn.Module` as the function is run recursively. + modules_to_not_convert (`List[`str`]`, *optional*, defaults to `["lm_head"]`): + Names of the modules to not convert in `Linear8bitLt`. In practice we keep the `lm_head` in full precision + for numerical stability reasons. 
+        current_key_name (`List[`str`]`, *optional*):
+            An array to track the current key of the recursion. This is used to check whether the current key (part of
+            it) is not in the list of modules to not convert (for instance, modules that are offloaded to `cpu` or
+            `disk`).
+    """
+    modules_to_not_convert = ["lm_head"] if modules_to_not_convert is None else modules_to_not_convert
+    model, has_been_replaced = _replace_with_bnb_linear(
+        model, modules_to_not_convert, current_key_name, quantization_config
+    )
+
+    if not has_been_replaced:
+        logger.warning(
+            "You are loading your model in 8bit or 4bit but no linear modules were found in your model."
+            " Please double check your model architecture, or submit an issue on github if you think this is"
+            " a bug."
+        )
+
+    return model
+
+
+# For backward compatibility
+def replace_8bit_linear(*args, **kwargs):
+    warnings.warn(
+        "`replace_8bit_linear` will be deprecated in a future version, please use `replace_with_bnb_linear` instead",
+        FutureWarning,
+    )
+    return replace_with_bnb_linear(*args, **kwargs)
+
+
+# For backward compatibility
+def set_module_8bit_tensor_to_device(*args, **kwargs):
+    warnings.warn(
+        "`set_module_8bit_tensor_to_device` will be deprecated in a future version, please use `set_module_quantized_tensor_to_device` instead",
+        FutureWarning,
+    )
+    return set_module_quantized_tensor_to_device(*args, **kwargs)
+
+
+def get_keys_to_not_convert(model):
+    r"""
+    A utility function to get the keys of the modules to keep in full precision, if any. For example, for CausalLM
+    modules we may want to keep the lm_head in full precision for numerical stability reasons. For other
+    architectures, we want to keep the tied weights of the model. The function will return a list of the keys of the
+    modules to not convert in int8.
+
+    Parameters:
+    model (`torch.nn.Module`):
+        Input model
+    """
+    # Create a copy of the model and tie the weights, then
+    # check if it contains tied weights
+    tied_model = deepcopy(model)  # this has 0 cost since it is done inside `init_empty_weights` context manager
+    tied_model.tie_weights()
+
+    tied_params = find_tied_parameters(tied_model)
+    # For compatibility with Accelerate < 0.18
+    if isinstance(tied_params, dict):
+        tied_keys = sum(list(tied_params.values()), []) + list(tied_params.keys())
+    else:
+        tied_keys = sum(tied_params, [])
+    has_tied_params = len(tied_keys) > 0
+
+    # If there are no tied weights, we want to keep the lm_head (output_embedding) in full precision
+    if not has_tied_params:
+        output_emb = model.get_output_embeddings()
+        if output_emb is not None:
+            list_last_module = [name for name, module in model.named_modules() if id(module) == id(output_emb)]
+            return list_last_module
+
+    # otherwise (there are tied weights, or no output embedding was found), keep the last module and the tied
+    # weights in full precision
+    list_modules = list(model.named_parameters())
+    list_last_module = [list_modules[-1][0]]
+    # add last module together with tied weights
+    intersection = set(list_last_module) - set(tied_keys)
+    list_untouched = list(set(tied_keys)) + list(intersection)
+
+    # remove ".weight" and ".bias" from the keys
+    names_to_remove = [".weight", ".bias"]
+    filtered_module_names = []
+    for name in list_untouched:
+        for name_to_remove in names_to_remove:
+            if name_to_remove in name:
+                name = name.replace(name_to_remove, "")
+        filtered_module_names.append(name)
+
+    return filtered_module_names
diff --git a/src/transformers/integrations/deepspeed.py b/src/transformers/integrations/deepspeed.py
new file mode 100644
index 000000000000..efeccb85c246
--- /dev/null
+++ b/src/transformers/integrations/deepspeed.py
@@ -0,0 +1,389 @@
+# Copyright 2020 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Integration with Deepspeed
+"""
+
+import importlib.util
+import weakref
+from functools import partialmethod
+
+from ..dependency_versions_check import dep_version_check
+from ..utils import is_accelerate_available, is_torch_available, logging
+
+
+if is_torch_available():
+    import torch
+
+logger = logging.get_logger(__name__)
+
+
+def is_deepspeed_available():
+    return importlib.util.find_spec("deepspeed") is not None
+
+
+if is_accelerate_available() and is_deepspeed_available():
+    from accelerate.utils.deepspeed import HfDeepSpeedConfig as DeepSpeedConfig
+else:
+    # Inherits from a dummy `object` if accelerate is not available, so that python succeeds to import this file.
+    # Deepspeed glue code will never inherit this dummy object as it checks if accelerate is available.
+    from builtins import object as DeepSpeedConfig
+
+
+class HfDeepSpeedConfig(DeepSpeedConfig):
+    """
+    This object contains a DeepSpeed configuration dictionary and can be quickly queried for things like zero stage.
+ + A `weakref` of this object is stored in the module's globals to be able to access the config from areas where + things like the Trainer object is not available (e.g. `from_pretrained` and `_get_resized_embeddings`). Therefore + it's important that this object remains alive while the program is still running. + + [`Trainer`] uses the `HfTrainerDeepSpeedConfig` subclass instead. That subclass has logic to sync the configuration + with values of [`TrainingArguments`] by replacing special placeholder values: `"auto"`. Without this special logic + the DeepSpeed configuration is not modified in any way. + + Args: + config_file_or_dict (`Union[str, Dict]`): path to DeepSpeed config file or dict. + + """ + + def __init__(self, config_file_or_dict): + # set global weakref object + set_hf_deepspeed_config(self) + dep_version_check("accelerate") + dep_version_check("deepspeed") + super().__init__(config_file_or_dict) + + +class HfTrainerDeepSpeedConfig(HfDeepSpeedConfig): + """ + The `HfTrainerDeepSpeedConfig` object is meant to be created during `TrainingArguments` object creation and has the + same lifespan as the latter. + """ + + def __init__(self, config_file_or_dict): + super().__init__(config_file_or_dict) + self._dtype = None + self.mismatches = [] + + def dtype(self): + if self._dtype is None: + raise ValueError("trainer_config_process() wasn't called yet to tell dtype") + return self._dtype + + def is_auto(self, ds_key_long): + val = self.get_value(ds_key_long) + if val is None: + return False + else: + return val == "auto" + + def fill_match(self, ds_key_long, hf_val, hf_key=None, must_match=True): + """ + A utility method that massages the config file and can optionally verify that the values match. + + 1. Replace "auto" values with `TrainingArguments` value. + + 2. If it wasn't "auto" and `must_match` is true, then check that DS config matches Trainer + config values and if mismatched add the entry to `self.mismatched` - will assert during + `trainer_config_finalize` for one or more mismatches. + + """ + config, ds_key = self.find_config_node(ds_key_long) + if config is None: + return + + if config.get(ds_key) == "auto": + config[ds_key] = hf_val + return + + if not must_match: + return + + ds_val = config.get(ds_key) + if ds_val is not None and ds_val != hf_val: + self.mismatches.append(f"- ds {ds_key_long}={ds_val} vs hf {hf_key}={hf_val}") + + fill_only = partialmethod(fill_match, must_match=False) + + def trainer_config_process(self, args): + """ + Adjust the config with `TrainingArguments` values. This stage is run during `TrainingArguments` object + creation. 
+ """ + # DeepSpeed does: + # train_batch_size = world_size * train_micro_batch_size_per_gpu * gradient_accumulation_steps + train_batch_size = args.world_size * args.per_device_train_batch_size * args.gradient_accumulation_steps + self.fill_match( + "train_micro_batch_size_per_gpu", args.per_device_train_batch_size, "per_device_train_batch_size" + ) + self.fill_match("gradient_accumulation_steps", args.gradient_accumulation_steps, "gradient_accumulation_steps") + self.fill_match("train_batch_size", train_batch_size, "train_batch_size (calculated)") + self.fill_match("gradient_clipping", args.max_grad_norm, "max_grad_norm") + + self.fill_match("optimizer.params.lr", args.learning_rate, "learning_rate") + self.fill_match("optimizer.params.betas", [args.adam_beta1, args.adam_beta2], "adam_beta1+adam_beta2") + self.fill_match("optimizer.params.eps", args.adam_epsilon, "adam_epsilon") + self.fill_match("optimizer.params.weight_decay", args.weight_decay, "weight_decay") + + self.fill_only("scheduler.params.warmup_min_lr", 0) # not a trainer arg + self.fill_match("scheduler.params.warmup_max_lr", args.learning_rate, "learning_rate") + # total_num_steps - will get set in trainer_config_finalize + + # fp16 + if args.fp16 or args.fp16_full_eval: + fp16_backend = "apex" if args.fp16_backend == "apex" else "amp" + else: + fp16_backend = None + + if args.save_on_each_node: + # deepspeed uses shared storage by default. Let's override this setting if save_on_each_node == True + self.config["checkpoint"] = self.config.get("checkpoint", {}) + self.config["checkpoint"]["use_node_local_storage"] = args.save_on_each_node + + # amp: similar to the pytorch native amp - it has a bunch of optional params but we won't set + # any here unless the user did the work + self.fill_match( + "fp16.enabled", + ((args.fp16 or args.fp16_full_eval) and fp16_backend == "amp"), + "fp16|fp16_full_eval+fp16_backend(amp)", + ) + + # apex: delegates amp work to apex (which needs to be available), but it cannot be used with any + # ZeRO features + self.fill_match("amp.enabled", fp16_backend == "apex", "fp16+fp16_backend(apex)") + self.fill_match("amp.opt_level", args.fp16_opt_level, "fp16_opt_level") + + self.fill_match("bf16.enabled", (args.bf16 or args.bf16_full_eval), "bf16|bf16_full_eval") + + # deepspeed's default mode is fp16 unless there is a config that says differently + if self.is_true("bf16.enabled"): + self._dtype = torch.bfloat16 + elif self.is_false("fp16.enabled"): + self._dtype = torch.float32 + else: + self._dtype = torch.float16 + + def trainer_config_finalize(self, args, model, num_training_steps): + """ + This stage is run after we have the model and know num_training_steps. + + Now we can complete the configuration process. 
+ """ + # zero + + # deal with config keys that use `auto` value and rely on model's hidden_size + hidden_size_based_keys = [ + "zero_optimization.reduce_bucket_size", + "zero_optimization.stage3_prefetch_bucket_size", + "zero_optimization.stage3_param_persistence_threshold", + ] + hidden_size_auto_keys = [x for x in hidden_size_based_keys if self.is_auto(x)] + + if len(hidden_size_auto_keys) > 0: + if hasattr(model.config, "hidden_size"): + hidden_size = model.config.hidden_size + elif hasattr(model.config, "hidden_sizes"): + # if there are many hidden sizes pick the largest one + hidden_size = max(model.config.hidden_sizes) + else: + raise ValueError( + "The model's config file has neither `hidden_size` nor `hidden_sizes` entry, " + "therefore it's not possible to automatically fill out the following `auto` entries " + f"in the DeepSpeed config file: {hidden_size_auto_keys}. You can fix that by replacing " + "`auto` values for these keys with an integer value of your choice." + ) + + self.fill_only("zero_optimization.reduce_bucket_size", hidden_size * hidden_size) + if self.is_zero3(): + # automatically assign the optimal config values based on model config + self.fill_only("zero_optimization.stage3_prefetch_bucket_size", 0.9 * hidden_size * hidden_size) + self.fill_only("zero_optimization.stage3_param_persistence_threshold", 10 * hidden_size) + + # scheduler + self.fill_match("scheduler.params.total_num_steps", num_training_steps, "num_training_steps (calculated)") + self.fill_match("scheduler.params.warmup_num_steps", args.get_warmup_steps(num_training_steps), "warmup_steps") + + if len(self.mismatches) > 0: + mismatches = "\n".join(self.mismatches) + raise ValueError( + "Please correct the following DeepSpeed config values that mismatch TrainingArguments" + f" values:\n{mismatches}\nThe easiest method is to set these DeepSpeed config values to 'auto'." + ) + + +# keep the config object global to be able to access it anywhere during TrainingArguments life-cycle +_hf_deepspeed_config_weak_ref = None + + +def set_hf_deepspeed_config(hf_deepspeed_config_obj): + # this is a special weakref global object to allow us to get to Deepspeed config from APIs + # that don't have an easy way to get to the Deepspeed config outside of the Trainer domain. + global _hf_deepspeed_config_weak_ref + # will go away automatically when HfDeepSpeedConfig is destroyed (when TrainingArguments is destroyed) + _hf_deepspeed_config_weak_ref = weakref.ref(hf_deepspeed_config_obj) + + +def unset_hf_deepspeed_config(): + # useful for unit tests to ensure the global state doesn't leak - call from `tearDown` method + global _hf_deepspeed_config_weak_ref + _hf_deepspeed_config_weak_ref = None + + +def is_deepspeed_zero3_enabled(): + if _hf_deepspeed_config_weak_ref is not None and _hf_deepspeed_config_weak_ref() is not None: + return _hf_deepspeed_config_weak_ref().is_zero3() + else: + return False + + +def deepspeed_config(): + if _hf_deepspeed_config_weak_ref is not None and _hf_deepspeed_config_weak_ref() is not None: + return _hf_deepspeed_config_weak_ref().config + else: + return None + + +def deepspeed_optim_sched(trainer, hf_deepspeed_config, args, num_training_steps, model_parameters): + """ + A convenience wrapper that deals with optimizer and lr scheduler configuration. + """ + from accelerate.utils import DummyOptim, DummyScheduler + + config = hf_deepspeed_config.config + + # Optimizer + Scheduler + # Currently supported combos: + # 1. DS scheduler + DS optimizer: Yes + # 2. 
HF scheduler + HF optimizer: Yes + # 3. DS scheduler + HF optimizer: Yes + # 4. HF scheduler + DS optimizer: No + # + # Unless Offload is enabled in which case it's: + # 1. DS scheduler + DS optimizer: Yes + # 2. HF scheduler + HF optimizer: Mostly* + # 3. DS scheduler + HF optimizer: Mostly* + # 4. HF scheduler + DS optimizer: No + # + # Mostly*: All non-native DeepSpeed optimizers that have both CPU and GPU implementation should work (except LAMB) + + optimizer = None + if "optimizer" in config: + if args.adafactor: + raise ValueError( + "--adafactor was passed, but also found `optimizer` configured in the DeepSpeed config. " + "Only one optimizer can be configured." + ) + optimizer = DummyOptim(params=model_parameters) + else: + if hf_deepspeed_config.is_offload(): + logger.info( + "Detected ZeRO Offload and non-DeepSpeed optimizers: This combination should work as long as the" + " custom optimizer has both CPU and GPU implementation (except LAMB)" + ) + + # ds supports Adam, OneBitAdam, and Lamb optimizers and can import other optimizers from torch. + # But trainer uses AdamW by default. + optimizer = trainer.create_optimizer() + # To use other optimizers requires voiding warranty with: `zero_allow_untested_optimizer` + config["zero_allow_untested_optimizer"] = True + + lr_scheduler = None + if "scheduler" in config: + lr_scheduler = DummyScheduler(optimizer) + else: + if isinstance(optimizer, DummyOptim): + raise ValueError( + "Found `optimizer` configured in the DeepSpeed config, but no `scheduler`. " + "Please configure a scheduler in the DeepSpeed config." + ) + lr_scheduler = trainer.create_scheduler(num_training_steps=num_training_steps, optimizer=optimizer) + + return optimizer, lr_scheduler + + +def deepspeed_init(trainer, num_training_steps, inference=False): + """ + Init DeepSpeed, after updating the DeepSpeed configuration with any relevant Trainer's args. + + If `resume_from_checkpoint` was passed then an attempt to resume from a previously saved checkpoint will be made. 
+ + Args: + trainer: Trainer object + num_training_steps: per single gpu + resume_from_checkpoint: path to a checkpoint if to resume from after normal DeepSpeedEngine load + inference: launch in inference mode (no optimizer and no lr scheduler) + + Returns: optimizer, lr_scheduler + + We may use `deepspeed_init` more than once during the life of Trainer, when we do - it's a temp hack based on: + https://github.com/microsoft/DeepSpeed/issues/1394#issuecomment-937405374 until Deepspeed fixes a bug where it + can't resume from a checkpoint after it did some stepping https://github.com/microsoft/DeepSpeed/issues/1612 + + """ + from deepspeed.utils import logger as ds_logger + + model = trainer.model + args = trainer.args + + hf_deepspeed_config = trainer.accelerator.state.deepspeed_plugin.hf_ds_config + + # resume config update - some bits like `model` and `num_training_steps` only become available during train + hf_deepspeed_config.trainer_config_finalize(args, model, num_training_steps) + + # set the Deepspeed log level consistent with the Trainer + ds_logger.setLevel(args.get_process_log_level()) + + if inference: + # only Z3 makes sense for the inference + if not hf_deepspeed_config.is_zero3(): + raise ValueError("ZeRO inference only makes sense with ZeRO Stage 3 - please adjust your config") + + # in case the training config is re-used for inference + hf_deepspeed_config.del_config_sub_tree("optimizer") + hf_deepspeed_config.del_config_sub_tree("lr_scheduler") + optimizer, lr_scheduler = None, None + model_parameters = None + else: + trainer.optimizer = None # important for when deepspeed_init is used as re-init + model_parameters = list(filter(lambda p: p.requires_grad, model.parameters())) + optimizer, lr_scheduler = deepspeed_optim_sched( + trainer, hf_deepspeed_config, args, num_training_steps, model_parameters + ) + + # keep for quick debug: + # from pprint import pprint; pprint(config) + + return optimizer, lr_scheduler + + +def deepspeed_load_checkpoint(deepspeed_engine, checkpoint_path): + # it's possible that the user is trying to resume from model_path, which doesn't necessarily + # contain a deepspeed checkpoint. e.g. examples just check if the dir exists and assume it's + # a resume from a checkpoint and not just a local pretrained weight. So we check here if the + # path contains what looks like a deepspeed checkpoint + import glob + + deepspeed_checkpoint_dirs = sorted(glob.glob(f"{checkpoint_path}/global_step*")) + + if len(deepspeed_checkpoint_dirs) > 0: + logger.info(f"Attempting to resume from {checkpoint_path}") + # this magically updates self.optimizer and self.lr_scheduler + load_path, _ = deepspeed_engine.load_checkpoint( + checkpoint_path, load_optimizer_states=True, load_lr_scheduler_states=True + ) + if load_path is None: + raise ValueError(f"[deepspeed] failed to resume from checkpoint {checkpoint_path}") + else: + raise ValueError(f"Can't find a valid checkpoint at {checkpoint_path}") diff --git a/src/transformers/integrations.py b/src/transformers/integrations/integration_utils.py similarity index 99% rename from src/transformers/integrations.py rename to src/transformers/integrations/integration_utils.py index c2c37fffdf9f..a88862851410 100644 --- a/src/transformers/integrations.py +++ b/src/transformers/integrations/integration_utils.py @@ -30,8 +30,8 @@ import numpy as np -from . import __version__ as version -from .utils import flatten_dict, is_datasets_available, is_pandas_available, is_torch_available, logging +from .. 
import __version__ as version
+from ..utils import flatten_dict, is_datasets_available, is_pandas_available, is_torch_available, logging


 logger = logging.get_logger(__name__)

@@ -68,10 +68,10 @@
 except importlib.metadata.PackageNotFoundError:
     _has_neptune = False

-from .trainer_callback import ProgressCallback, TrainerCallback  # noqa: E402
-from .trainer_utils import PREFIX_CHECKPOINT_DIR, BestRun, IntervalStrategy  # noqa: E402
-from .training_args import ParallelMode  # noqa: E402
-from .utils import ENV_VARS_TRUE_VALUES, is_torch_tpu_available  # noqa: E402
+from ..trainer_callback import ProgressCallback, TrainerCallback  # noqa: E402
+from ..trainer_utils import PREFIX_CHECKPOINT_DIR, BestRun, IntervalStrategy  # noqa: E402
+from ..training_args import ParallelMode  # noqa: E402
+from ..utils import ENV_VARS_TRUE_VALUES, is_torch_tpu_available  # noqa: E402


 # Integration functions:
diff --git a/src/transformers/lib_integrations/peft/peft_mixin.py b/src/transformers/integrations/peft.py
similarity index 99%
rename from src/transformers/lib_integrations/peft/peft_mixin.py
rename to src/transformers/integrations/peft.py
index 7a1f7c1f582e..432c0d3c2bb0 100644
--- a/src/transformers/lib_integrations/peft/peft_mixin.py
+++ b/src/transformers/integrations/peft.py
@@ -14,7 +14,7 @@
 import inspect
 from typing import Optional

-from ...utils import (
+from ..utils import (
     check_peft_version,
     find_adapter_config_file,
     is_accelerate_available,
diff --git a/src/transformers/lib_integrations/__init__.py b/src/transformers/lib_integrations/__init__.py
deleted file mode 100644
index 0a2b0329f696..000000000000
--- a/src/transformers/lib_integrations/__init__.py
+++ /dev/null
@@ -1,14 +0,0 @@
-# Copyright 2023 The HuggingFace Team. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-from .peft import PeftAdapterMixin
diff --git a/src/transformers/lib_integrations/peft/__init__.py b/src/transformers/lib_integrations/peft/__init__.py
deleted file mode 100644
index a6c1f0afd7e3..000000000000
--- a/src/transformers/lib_integrations/peft/__init__.py
+++ /dev/null
@@ -1,15 +0,0 @@
-# Copyright 2023 The HuggingFace Team. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from .peft_mixin import PeftAdapterMixin
diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py
index c5925467677f..497ef408d1e3 100644
--- a/src/transformers/modeling_utils.py
+++ b/src/transformers/modeling_utils.py
@@ -35,10 +35,9 @@
 from .activations import get_activation
 from .configuration_utils import PretrainedConfig
-from .deepspeed import deepspeed_config, is_deepspeed_zero3_enabled
 from .dynamic_module_utils import custom_object_save
 from .generation import GenerationConfig, GenerationMixin
-from .lib_integrations import PeftAdapterMixin
+from .integrations import PeftAdapterMixin, deepspeed_config, is_deepspeed_zero3_enabled
 from .pytorch_utils import (  # noqa: F401
     Conv1D,
     apply_chunking_to_forward,
@@ -663,7 +662,7 @@ def _load_state_dict_into_meta_model(
     # they won't get loaded.

     if is_quantized:
-        from .utils.bitsandbytes import set_module_quantized_tensor_to_device
+        from .integrations import set_module_quantized_tensor_to_device

     error_msgs = []

@@ -2960,7 +2959,7 @@ def from_pretrained(
         keep_in_fp32_modules = []

         if load_in_8bit or load_in_4bit:
-            from .utils.bitsandbytes import get_keys_to_not_convert, replace_with_bnb_linear
+            from .integrations import get_keys_to_not_convert, replace_with_bnb_linear

             llm_int8_skip_modules = quantization_config.llm_int8_skip_modules
             load_in_8bit_fp32_cpu_offload = quantization_config.llm_int8_enable_fp32_cpu_offload
@@ -3278,7 +3277,7 @@ def _load_pretrained_model(
     ):
         is_safetensors = False

         if is_quantized:
-            from .utils.bitsandbytes import set_module_quantized_tensor_to_device
+            from .integrations import set_module_quantized_tensor_to_device

         if device_map is not None and "disk" in device_map.values():
             archive_file = (
diff --git a/src/transformers/models/data2vec/modeling_data2vec_audio.py b/src/transformers/models/data2vec/modeling_data2vec_audio.py
index a42fb5eb0678..b886c6ad48ce 100755
--- a/src/transformers/models/data2vec/modeling_data2vec_audio.py
+++ b/src/transformers/models/data2vec/modeling_data2vec_audio.py
@@ -25,7 +25,7 @@
 from torch.nn import CrossEntropyLoss

 from ...activations import ACT2FN
-from ...deepspeed import is_deepspeed_zero3_enabled
+from ...integrations.deepspeed import is_deepspeed_zero3_enabled
 from ...modeling_outputs import (
     BaseModelOutput,
     CausalLMOutput,
diff --git a/src/transformers/models/deprecated/mctct/modeling_mctct.py b/src/transformers/models/deprecated/mctct/modeling_mctct.py
index 5b98a8a6079b..eca5ba014e51 100755
--- a/src/transformers/models/deprecated/mctct/modeling_mctct.py
+++ b/src/transformers/models/deprecated/mctct/modeling_mctct.py
@@ -23,8 +23,8 @@
 from torch import nn

 from ....activations import ACT2FN
-from ....deepspeed import is_deepspeed_zero3_enabled
 from ....file_utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward
+from ....integrations.deepspeed import is_deepspeed_zero3_enabled
 from ....modeling_outputs import BaseModelOutput, CausalLMOutput
 from ....modeling_utils import (
     PreTrainedModel,
diff --git a/src/transformers/models/distilbert/modeling_distilbert.py b/src/transformers/models/distilbert/modeling_distilbert.py
index a6a884398b2a..f26b5846972d 100755
--- a/src/transformers/models/distilbert/modeling_distilbert.py
+++ b/src/transformers/models/distilbert/modeling_distilbert.py
@@ -29,7 +29,7 @@
 from ...activations import get_activation
 from ...configuration_utils import PretrainedConfig
-from ...deepspeed import is_deepspeed_zero3_enabled
+from ...integrations.deepspeed import is_deepspeed_zero3_enabled
 from ...modeling_outputs import (
     BaseModelOutput,
     MaskedLMOutput,
diff --git a/src/transformers/models/esm/modeling_esmfold.py b/src/transformers/models/esm/modeling_esmfold.py
index 6e6c9fcb4aea..4fdc37e240b9 100644
--- a/src/transformers/models/esm/modeling_esmfold.py
+++ b/src/transformers/models/esm/modeling_esmfold.py
@@ -23,7 +23,7 @@
 import torch.nn as nn
 from torch.nn import LayerNorm

-from ...deepspeed import is_deepspeed_available
+from ...integrations.deepspeed import is_deepspeed_available
 from ...modeling_outputs import ModelOutput
 from ...utils import (
     ContextManagers,
diff --git a/src/transformers/models/fsmt/modeling_fsmt.py b/src/transformers/models/fsmt/modeling_fsmt.py
index db239437e125..29d78c12695e 100644
--- a/src/transformers/models/fsmt/modeling_fsmt.py
+++ b/src/transformers/models/fsmt/modeling_fsmt.py
@@ -35,7 +35,7 @@
 from torch.nn import CrossEntropyLoss, LayerNorm

 from ...activations import ACT2FN
-from ...deepspeed import is_deepspeed_zero3_enabled
+from ...integrations.deepspeed import is_deepspeed_zero3_enabled
 from ...modeling_outputs import (
     BaseModelOutput,
     BaseModelOutputWithPastAndCrossAttentions,
diff --git a/src/transformers/models/hubert/modeling_hubert.py b/src/transformers/models/hubert/modeling_hubert.py
index 8228520dfd5e..948530bb6b3f 100755
--- a/src/transformers/models/hubert/modeling_hubert.py
+++ b/src/transformers/models/hubert/modeling_hubert.py
@@ -24,7 +24,7 @@
 from torch.nn import CrossEntropyLoss

 from ...activations import ACT2FN
-from ...deepspeed import is_deepspeed_zero3_enabled
+from ...integrations.deepspeed import is_deepspeed_zero3_enabled
 from ...modeling_outputs import BaseModelOutput, CausalLMOutput, SequenceClassifierOutput
 from ...modeling_utils import PreTrainedModel
 from ...utils import (
diff --git a/src/transformers/models/m2m_100/modeling_m2m_100.py b/src/transformers/models/m2m_100/modeling_m2m_100.py
index 5d9dbccffd1a..b6c31518390d 100755
--- a/src/transformers/models/m2m_100/modeling_m2m_100.py
+++ b/src/transformers/models/m2m_100/modeling_m2m_100.py
@@ -23,7 +23,7 @@
 from torch.nn import CrossEntropyLoss

 from ...activations import ACT2FN
-from ...deepspeed import is_deepspeed_zero3_enabled
+from ...integrations.deepspeed import is_deepspeed_zero3_enabled
 from ...modeling_outputs import (
     BaseModelOutput,
     BaseModelOutputWithPastAndCrossAttentions,
diff --git a/src/transformers/models/nllb_moe/modeling_nllb_moe.py b/src/transformers/models/nllb_moe/modeling_nllb_moe.py
index cf2bdd5e52e4..53f01328c9e5 100644
--- a/src/transformers/models/nllb_moe/modeling_nllb_moe.py
+++ b/src/transformers/models/nllb_moe/modeling_nllb_moe.py
@@ -24,7 +24,7 @@
 from torch.utils.checkpoint import checkpoint

 from ...activations import ACT2FN
-from ...deepspeed import is_deepspeed_zero3_enabled
+from ...integrations.deepspeed import is_deepspeed_zero3_enabled
 from ...modeling_outputs import (
     MoEModelOutput,
     MoEModelOutputWithPastAndCrossAttentions,
diff --git a/src/transformers/models/sew/modeling_sew.py b/src/transformers/models/sew/modeling_sew.py
index 67b4bf1a0c6c..17364a255b9c 100644
--- a/src/transformers/models/sew/modeling_sew.py
+++ b/src/transformers/models/sew/modeling_sew.py
@@ -25,7 +25,7 @@
 from torch.nn import CrossEntropyLoss

 from ...activations import ACT2FN
-from ...deepspeed import is_deepspeed_zero3_enabled
+from ...integrations.deepspeed import is_deepspeed_zero3_enabled
 from ...modeling_outputs import BaseModelOutput, CausalLMOutput, SequenceClassifierOutput
 from ...modeling_utils import PreTrainedModel
 from ...utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, logging
diff --git a/src/transformers/models/sew_d/modeling_sew_d.py b/src/transformers/models/sew_d/modeling_sew_d.py
index 6ae717d9a28a..fbc6c4ced27e 100644
--- a/src/transformers/models/sew_d/modeling_sew_d.py
+++ b/src/transformers/models/sew_d/modeling_sew_d.py
@@ -26,7 +26,7 @@
 from torch.nn import CrossEntropyLoss, LayerNorm

 from ...activations import ACT2FN
-from ...deepspeed import is_deepspeed_zero3_enabled
+from ...integrations.deepspeed import is_deepspeed_zero3_enabled
 from ...modeling_outputs import BaseModelOutput, CausalLMOutput, SequenceClassifierOutput
 from ...modeling_utils import PreTrainedModel
 from ...pytorch_utils import softmax_backward_data
diff --git a/src/transformers/models/speecht5/modeling_speecht5.py b/src/transformers/models/speecht5/modeling_speecht5.py
index df31075192a2..36229ba95c83 100644
--- a/src/transformers/models/speecht5/modeling_speecht5.py
+++ b/src/transformers/models/speecht5/modeling_speecht5.py
@@ -25,7 +25,7 @@
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, L1Loss

 from ...activations import ACT2FN
-from ...deepspeed import is_deepspeed_zero3_enabled
+from ...integrations.deepspeed import is_deepspeed_zero3_enabled
 from ...modeling_outputs import (
     BaseModelOutput,
     BaseModelOutputWithPastAndCrossAttentions,
diff --git a/src/transformers/models/unispeech/modeling_unispeech.py b/src/transformers/models/unispeech/modeling_unispeech.py
index 9737433089f8..4c6a1ec13daa 100755
--- a/src/transformers/models/unispeech/modeling_unispeech.py
+++ b/src/transformers/models/unispeech/modeling_unispeech.py
@@ -26,7 +26,7 @@
 from torch.nn import CrossEntropyLoss

 from ...activations import ACT2FN
-from ...deepspeed import is_deepspeed_zero3_enabled
+from ...integrations.deepspeed import is_deepspeed_zero3_enabled
 from ...modeling_outputs import BaseModelOutput, CausalLMOutput, SequenceClassifierOutput, Wav2Vec2BaseModelOutput
 from ...modeling_utils import PreTrainedModel
 from ...utils import (
diff --git a/src/transformers/models/unispeech_sat/modeling_unispeech_sat.py b/src/transformers/models/unispeech_sat/modeling_unispeech_sat.py
index 4c4ab4b90f3b..73906c691208 100755
--- a/src/transformers/models/unispeech_sat/modeling_unispeech_sat.py
+++ b/src/transformers/models/unispeech_sat/modeling_unispeech_sat.py
@@ -26,7 +26,7 @@
 from torch.nn import CrossEntropyLoss

 from ...activations import ACT2FN
-from ...deepspeed import is_deepspeed_zero3_enabled
+from ...integrations.deepspeed import is_deepspeed_zero3_enabled
 from ...modeling_outputs import (
     BaseModelOutput,
     CausalLMOutput,
diff --git a/src/transformers/models/wav2vec2/modeling_wav2vec2.py b/src/transformers/models/wav2vec2/modeling_wav2vec2.py
index 95f465262c47..af74533ad062 100755
--- a/src/transformers/models/wav2vec2/modeling_wav2vec2.py
+++ b/src/transformers/models/wav2vec2/modeling_wav2vec2.py
@@ -26,7 +26,7 @@
 from torch.nn import CrossEntropyLoss

 from ...activations import ACT2FN
-from ...deepspeed import is_deepspeed_zero3_enabled
+from ...integrations.deepspeed import is_deepspeed_zero3_enabled
 from ...modeling_outputs import (
     BaseModelOutput,
     CausalLMOutput,
diff --git a/src/transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py b/src/transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py
index f4392073b9a4..76ed22f70eb6 100644
--- a/src/transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py
+++ b/src/transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py
@@ -25,7 +25,7 @@
 from torch.nn import CrossEntropyLoss

 from ...activations import ACT2FN
-from ...deepspeed import is_deepspeed_zero3_enabled
+from ...integrations.deepspeed import is_deepspeed_zero3_enabled
 from ...modeling_outputs import (
     BaseModelOutput,
     CausalLMOutput,
diff --git a/src/transformers/models/wavlm/modeling_wavlm.py b/src/transformers/models/wavlm/modeling_wavlm.py
index 61a75babf1c0..9cf67a458b46 100755
--- a/src/transformers/models/wavlm/modeling_wavlm.py
+++ b/src/transformers/models/wavlm/modeling_wavlm.py
@@ -26,7 +26,7 @@
 from torch.nn import CrossEntropyLoss

 from ...activations import ACT2FN
-from ...deepspeed import is_deepspeed_zero3_enabled
+from ...integrations.deepspeed import is_deepspeed_zero3_enabled
 from ...modeling_outputs import (
     BaseModelOutput,
     CausalLMOutput,
diff --git a/src/transformers/testing_utils.py b/src/transformers/testing_utils.py
index c8c66657792e..85b947d706aa 100644
--- a/src/transformers/testing_utils.py
+++ b/src/transformers/testing_utils.py
@@ -40,7 +40,6 @@
 from transformers import logging as transformers_logging

-from .deepspeed import is_deepspeed_available
 from .integrations import (
     is_clearml_available,
     is_fairscale_available,
@@ -49,6 +48,7 @@
     is_sigopt_available,
     is_wandb_available,
 )
+from .integrations.deepspeed import is_deepspeed_available
 from .utils import (
     is_accelerate_available,
     is_apex_available,
diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py
index 85d4fd5a5252..2694cff70afe 100755
--- a/src/transformers/trainer.py
+++ b/src/transformers/trainer.py
@@ -58,9 +58,9 @@
 from .configuration_utils import PretrainedConfig
 from .data.data_collator import DataCollator, DataCollatorWithPadding, default_data_collator
 from .debug_utils import DebugOption, DebugUnderflowOverflow
-from .deepspeed import deepspeed_init, deepspeed_load_checkpoint
 from .dependency_versions_check import dep_version_check
 from .hyperparameter_search import ALL_HYPERPARAMETER_SEARCH_BACKENDS, default_hp_search_backend
+from .integrations.deepspeed import deepspeed_init, deepspeed_load_checkpoint
 from .modelcard import TrainingSummary
 from .modeling_utils import PreTrainedModel, load_sharded_checkpoint, unwrap_model
 from .models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES, MODEL_MAPPING_NAMES
@@ -1197,7 +1197,7 @@ def _hp_search_setup(self, trial: Union["optuna.Trial", Dict[str, Any]]):
             # Rebuild the deepspeed config to reflect the updated training parameters
             from accelerate.utils import DeepSpeedPlugin

-            from transformers.deepspeed import HfTrainerDeepSpeedConfig
+            from transformers.integrations.deepspeed import HfTrainerDeepSpeedConfig

             self.args.hf_deepspeed_config = HfTrainerDeepSpeedConfig(self.args.deepspeed)
             self.args.hf_deepspeed_config.trainer_config_process(self.args)
@@ -3899,7 +3899,7 @@ def create_accelerator_and_postprocess(self):

         if self.is_deepspeed_enabled:
             if getattr(self.args, "hf_deepspeed_config", None) is None:
-                from transformers.deepspeed import HfTrainerDeepSpeedConfig
+                from transformers.integrations.deepspeed import HfTrainerDeepSpeedConfig

                 ds_plugin = self.accelerator.state.deepspeed_plugin
diff --git a/src/transformers/trainer_pt_utils.py b/src/transformers/trainer_pt_utils.py
index b57770f33b29..88e27e3c4dc7 100644
--- a/src/transformers/trainer_pt_utils.py
+++ b/src/transformers/trainer_pt_utils.py
@@ -35,7 +35,7 @@
 from torch.utils.data import Dataset, IterableDataset, RandomSampler, Sampler
 from torch.utils.data.distributed import DistributedSampler

-from .deepspeed import is_deepspeed_zero3_enabled
+from .integrations.deepspeed import is_deepspeed_zero3_enabled
 from .tokenization_utils_base import BatchEncoding
 from .utils import is_sagemaker_mp_enabled, is_torch_tpu_available, is_training_run_on_sagemaker, logging
diff --git a/src/transformers/trainer_seq2seq.py b/src/transformers/trainer_seq2seq.py
index 2ebc6036b0bb..569d939fdcf0 100644
--- a/src/transformers/trainer_seq2seq.py
+++ b/src/transformers/trainer_seq2seq.py
@@ -20,8 +20,8 @@
 from torch import nn
 from torch.utils.data import Dataset

-from .deepspeed import is_deepspeed_zero3_enabled
 from .generation.configuration_utils import GenerationConfig
+from .integrations.deepspeed import is_deepspeed_zero3_enabled
 from .trainer import Trainer
 from .utils import logging
diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py
index 62e3b515bd6a..11f812eaf2fc 100644
--- a/src/transformers/training_args.py
+++ b/src/transformers/training_args.py
@@ -1647,7 +1647,7 @@ def __post_init__(self):
             # - must be run before the model is created.
             if not is_accelerate_available():
                 raise ValueError("--deepspeed requires Accelerate to be installed: `pip install accelerate`.")
-            from transformers.deepspeed import HfTrainerDeepSpeedConfig
+            from transformers.integrations.deepspeed import HfTrainerDeepSpeedConfig

             # will be used later by the Trainer
             # note: leave self.deepspeed unmodified in case a user relies on it not to be modified)
diff --git a/src/transformers/utils/bitsandbytes.py b/src/transformers/utils/bitsandbytes.py
index 95a180dc5f48..71707cf56599 100644
--- a/src/transformers/utils/bitsandbytes.py
+++ b/src/transformers/utils/bitsandbytes.py
@@ -1,291 +1,28 @@
-import importlib.metadata
+# Copyright 2023 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
 import warnings
-from copy import deepcopy
-from packaging import version
-from ..utils import logging
-from .import_utils import is_accelerate_available, is_bitsandbytes_available

+warnings.warn(
+    "transformers.utils.bitsandbytes module is deprecated and will be removed in a future version. Please import bitsandbytes modules directly from transformers.integrations",
+    FutureWarning,
+)

-
-if is_bitsandbytes_available():
-    import bitsandbytes as bnb
-    import torch
-    import torch.nn as nn
-
-    from ..pytorch_utils import Conv1D
-
-if is_accelerate_available():
-    from accelerate import init_empty_weights
-    from accelerate.utils import find_tied_parameters
-
-logger = logging.get_logger(__name__)
-
-
-def set_module_quantized_tensor_to_device(module, tensor_name, device, value=None, fp16_statistics=None):
-    """
-    A helper function to set a given tensor (parameter of buffer) of a module on a specific device (note that doing
-    `param.to(device)` creates a new tensor not linked to the parameter, which is why we need this function). The
-    function is adapted from `set_module_tensor_to_device` function from accelerate that is adapted to support the
-    class `Int8Params` from `bitsandbytes`.
-
-    Args:
-        module (`torch.nn.Module`):
-            The module in which the tensor we want to move lives.
-        tensor_name (`str`):
-            The full name of the parameter/buffer.
-        device (`int`, `str` or `torch.device`):
-            The device on which to set the tensor.
-        value (`torch.Tensor`, *optional*):
-            The value of the tensor (useful when going from the meta device to any other device).
-        fp16_statistics (`torch.HalfTensor`, *optional*):
-            The list of fp16 statistics to set on the module, used for serialization.
-    """
-    # Recurse if needed
-    if "." in tensor_name:
-        splits = tensor_name.split(".")
-        for split in splits[:-1]:
-            new_module = getattr(module, split)
-            if new_module is None:
-                raise ValueError(f"{module} has no attribute {split}.")
-            module = new_module
-        tensor_name = splits[-1]
-
-    if tensor_name not in module._parameters and tensor_name not in module._buffers:
-        raise ValueError(f"{module} does not have a parameter or a buffer named {tensor_name}.")
-    is_buffer = tensor_name in module._buffers
-    old_value = getattr(module, tensor_name)
-
-    if old_value.device == torch.device("meta") and device not in ["meta", torch.device("meta")] and value is None:
-        raise ValueError(f"{tensor_name} is on the meta device, we need a `value` to put in on {device}.")
-
-    is_4bit = False
-    is_8bit = False
-    if is_buffer or not is_bitsandbytes_available():
-        is_8bit = False
-        is_4bit = False
-    else:
-        is_4bit = hasattr(bnb.nn, "Params4bit") and isinstance(module._parameters[tensor_name], bnb.nn.Params4bit)
-        is_8bit = isinstance(module._parameters[tensor_name], bnb.nn.Int8Params)
-
-    if is_8bit or is_4bit:
-        param = module._parameters[tensor_name]
-        if param.device.type != "cuda":
-            if value is None:
-                new_value = old_value.to(device)
-            elif isinstance(value, torch.Tensor):
-                new_value = value.to("cpu")
-                if value.dtype == torch.int8:
-                    is_8bit_serializable = version.parse(importlib.metadata.version("bitsandbytes")) > version.parse(
-                        "0.37.2"
-                    )
-                    if not is_8bit_serializable:
-                        raise ValueError(
-                            "Detected int8 weights but the version of bitsandbytes is not compatible with int8 serialization. "
-                            "Make sure to download the latest `bitsandbytes` version. `pip install --upgrade bitsandbytes`."
-                        )
-            else:
-                new_value = torch.tensor(value, device="cpu")
-
-            # Support models using `Conv1D` in place of `nn.Linear` (e.g. gpt2) by transposing the weight matrix prior to quantization.
-            # Since weights are saved in the correct "orientation", we skip transposing when loading.
-            if issubclass(module.source_cls, Conv1D) and fp16_statistics is None:
-                new_value = new_value.T
-
-            kwargs = old_value.__dict__
-            if is_8bit:
-                new_value = bnb.nn.Int8Params(new_value, requires_grad=False, **kwargs).to(device)
-            elif is_4bit:
-                new_value = bnb.nn.Params4bit(new_value, requires_grad=False, **kwargs).to(device)
-
-            module._parameters[tensor_name] = new_value
-            if fp16_statistics is not None:
-                setattr(module.weight, "SCB", fp16_statistics.to(device))
-
-    else:
-        if value is None:
-            new_value = old_value.to(device)
-        elif isinstance(value, torch.Tensor):
-            new_value = value.to(device)
-        else:
-            new_value = torch.tensor(value, device=device)
-
-        if is_buffer:
-            module._buffers[tensor_name] = new_value
-        else:
-            new_value = nn.Parameter(new_value, requires_grad=old_value.requires_grad)
-            module._parameters[tensor_name] = new_value
-
-
-def _replace_with_bnb_linear(
-    model, modules_to_not_convert=None, current_key_name=None, quantization_config=None, has_been_replaced=False
-):
-    """
-    Private method that wraps the recursion for module replacement.
-
-    Returns the converted model and a boolean that indicates if the conversion has been successfull or not.
-    """
-    for name, module in model.named_children():
-        if current_key_name is None:
-            current_key_name = []
-        current_key_name.append(name)
-
-        if (isinstance(module, nn.Linear) or isinstance(module, Conv1D)) and name not in modules_to_not_convert:
-            # Check if the current key is not in the `modules_to_not_convert`
-            if not any(key in ".".join(current_key_name) for key in modules_to_not_convert):
-                with init_empty_weights():
-                    if isinstance(module, Conv1D):
-                        in_features, out_features = module.weight.shape
-                    else:
-                        in_features = module.in_features
-                        out_features = module.out_features
-
-                    if quantization_config.quantization_method() == "llm_int8":
-                        model._modules[name] = bnb.nn.Linear8bitLt(
-                            in_features,
-                            out_features,
-                            module.bias is not None,
-                            has_fp16_weights=quantization_config.llm_int8_has_fp16_weight,
-                            threshold=quantization_config.llm_int8_threshold,
-                        )
-                        has_been_replaced = True
-                    else:
-                        if (
-                            quantization_config.llm_int8_skip_modules is not None
-                            and name in quantization_config.llm_int8_skip_modules
-                        ):
-                            pass
-                        else:
-                            model._modules[name] = bnb.nn.Linear4bit(
-                                in_features,
-                                out_features,
-                                module.bias is not None,
-                                quantization_config.bnb_4bit_compute_dtype,
-                                compress_statistics=quantization_config.bnb_4bit_use_double_quant,
-                                quant_type=quantization_config.bnb_4bit_quant_type,
-                            )
-                            has_been_replaced = True
-                    # Store the module class in case we need to transpose the weight later
-                    model._modules[name].source_cls = type(module)
-                    # Force requires grad to False to avoid unexpected errors
-                    model._modules[name].requires_grad_(False)
-        if len(list(module.children())) > 0:
-            _, has_been_replaced = _replace_with_bnb_linear(
-                module,
-                modules_to_not_convert,
-                current_key_name,
-                quantization_config,
-                has_been_replaced=has_been_replaced,
-            )
-        # Remove the last key for recursion
-        current_key_name.pop(-1)
-    return model, has_been_replaced
-
-
-def replace_with_bnb_linear(model, modules_to_not_convert=None, current_key_name=None, quantization_config=None):
-    """
-    A helper function to replace all `torch.nn.Linear` modules by `bnb.nn.Linear8bit` modules from the `bitsandbytes`
-    library. This will enable running your models using mixed int8 precision as described by the paper `LLM.int8():
-    8-bit Matrix Multiplication for Transformers at Scale`. Make sure `bitsandbytes` compiled with the correct CUDA
-    version of your hardware is installed before running this function. `pip install -i https://test.pypi.org/simple/
-    bitsandbytes`
-
-    The function will be run recursively and replace all `torch.nn.Linear` modules except for the `lm_head` that should
-    be kept as a `torch.nn.Linear` module. The replacement is done under `init_empty_weights` context manager so no
-    CPU/GPU memory is required to run this function. Int8 mixed-precision matrix decomposition works by separating a
-    matrix multiplication into two streams: (1) and systematic feature outlier stream matrix multiplied in fp16
-    (0.01%), (2) a regular stream of int8 matrix multiplication (99.9%). With this method, int8 inference with no
-    predictive degradation is possible for very large models (>=176B parameters).
-
-    Parameters:
-        model (`torch.nn.Module`):
-            Input model or `torch.nn.Module` as the function is run recursively.
-        modules_to_not_convert (`List[`str`]`, *optional*, defaults to `["lm_head"]`):
-            Names of the modules to not convert in `Linear8bitLt`. In practice we keep the `lm_head` in full precision
-            for numerical stability reasons.
-        current_key_name (`List[`str`]`, *optional*):
-            An array to track the current key of the recursion. This is used to check whether the current key (part of
-            it) is not in the list of modules to not convert (for instances modules that are offloaded to `cpu` or
-            `disk`).
-    """
-    modules_to_not_convert = ["lm_head"] if modules_to_not_convert is None else modules_to_not_convert
-    model, has_been_replaced = _replace_with_bnb_linear(
-        model, modules_to_not_convert, current_key_name, quantization_config
-    )
-
-    if not has_been_replaced:
-        logger.warning(
-            "You are loading your model in 8bit or 4bit but no linear modules were found in your model."
-            " Please double check your model architecture, or submit an issue on github if you think this is"
-            " a bug."
-        )
-
-    return model
-
-
-# For backward compatibility
-def replace_8bit_linear(*args, **kwargs):
-    warnings.warn(
-        "`replace_8bit_linear` will be deprecated in a future version, please use `replace_with_bnb_linear` instead",
-        FutureWarning,
-    )
-    return replace_with_bnb_linear(*args, **kwargs)
-
-
-# For backward compatiblity
-def set_module_8bit_tensor_to_device(*args, **kwargs):
-    warnings.warn(
-        "`set_module_8bit_tensor_to_device` will be deprecated in a future version, please use `set_module_quantized_tensor_to_device` instead",
-        FutureWarning,
-    )
-    return set_module_quantized_tensor_to_device(*args, **kwargs)
-
-
-def get_keys_to_not_convert(model):
-    r"""
-    An utility function to get the key of the module to keep in full precision if any For example for CausalLM modules
-    we may want to keep the lm_head in full precision for numerical stability reasons. For other architectures, we want
-    to keep the tied weights of the model. The function will return a list of the keys of the modules to not convert in
-    int8.
-
-    Parameters:
-        model (`torch.nn.Module`):
-            Input model
-    """
-    # Create a copy of the model and tie the weights, then
-    # check if it contains tied weights
-    tied_model = deepcopy(model)  # this has 0 cost since it is done inside `init_empty_weights` context manager`
-    tied_model.tie_weights()
-
-    tied_params = find_tied_parameters(tied_model)
-    # For compatibility with Accelerate < 0.18
-    if isinstance(tied_params, dict):
-        tied_keys = sum(list(tied_params.values()), []) + list(tied_params.keys())
-    else:
-        tied_keys = sum(tied_params, [])
-    has_tied_params = len(tied_keys) > 0
-
-    # If there is not tied weights, we want to keep the lm_head(output_embedding) in full precision
-    if not has_tied_params:
-        output_emb = model.get_output_embeddings()
-        if output_emb is not None:
-            list_last_module = [name for name, module in model.named_modules() if id(module) == id(output_emb)]
-            return list_last_module
-
-    # otherwise, no tied weights, no output embedding defined, simply keep the last module in full precision
-    list_modules = list(model.named_parameters())
-    list_last_module = [list_modules[-1][0]]
-    # add last module together with tied weights
-    intersection = set(list_last_module) - set(tied_keys)
-    list_untouched = list(set(tied_keys)) + list(intersection)
-
-    # remove ".weight" from the keys
-    names_to_remove = [".weight", ".bias"]
-    filtered_module_names = []
-    for name in list_untouched:
-        for name_to_remove in names_to_remove:
-            if name_to_remove in name:
-                name = name.replace(name_to_remove, "")
-        filtered_module_names.append(name)
-
-    return filtered_module_names
+from ..integrations import (  # noqa
+    get_keys_to_not_convert,
+    replace_8bit_linear,
+    replace_with_bnb_linear,
+    set_module_8bit_tensor_to_device,
+    set_module_quantized_tensor_to_device,
+)
diff --git a/tests/deepspeed/test_deepspeed.py b/tests/deepspeed/test_deepspeed.py
index c460bc9c1508..2fa1caf0b5ca 100644
--- a/tests/deepspeed/test_deepspeed.py
+++ b/tests/deepspeed/test_deepspeed.py
@@ -27,7 +27,11 @@
 import tests.trainer.test_trainer
 from tests.trainer.test_trainer import TrainerIntegrationCommon  # noqa
 from transformers import AutoModel, TrainingArguments, is_torch_available, logging
-from transformers.deepspeed import HfDeepSpeedConfig, is_deepspeed_available, unset_hf_deepspeed_config
+from transformers.integrations.deepspeed import (
+    HfDeepSpeedConfig,
+    is_deepspeed_available,
+    unset_hf_deepspeed_config,
+)
 from transformers.testing_utils import (
     CaptureLogger,
     CaptureStd,
@@ -113,7 +117,7 @@ def require_deepspeed_aio(test_case):
 if is_deepspeed_available():
     from deepspeed.utils import logger as deepspeed_logger  # noqa
     from deepspeed.utils.zero_to_fp32 import load_state_dict_from_zero_checkpoint
-    from transformers.deepspeed import deepspeed_config, is_deepspeed_zero3_enabled  # noqa
+    from transformers.integrations.deepspeed import deepspeed_config, is_deepspeed_zero3_enabled  # noqa


 def get_launcher(distributed=False):
diff --git a/tests/quantization/bnb/test_mixed_int8.py b/tests/quantization/bnb/test_mixed_int8.py
index 55e75bb52bd0..3f3a25913ed0 100644
--- a/tests/quantization/bnb/test_mixed_int8.py
+++ b/tests/quantization/bnb/test_mixed_int8.py
@@ -131,7 +131,7 @@ def test_get_keys_to_not_convert(self):
         from accelerate import init_empty_weights

         from transformers import AutoModelForMaskedLM, Blip2ForConditionalGeneration, MptForCausalLM, OPTForCausalLM
-        from transformers.utils.bitsandbytes import get_keys_to_not_convert
+        from transformers.integrations.bitsandbytes import get_keys_to_not_convert

         model_id = "mosaicml/mpt-7b"
         config = AutoConfig.from_pretrained(
diff --git a/utils/not_doctested.txt b/utils/not_doctested.txt
index 39f5316ed12c..ce7e150b23e4 100644
--- a/utils/not_doctested.txt
+++ b/utils/not_doctested.txt
@@ -383,9 +383,11 @@ src/transformers/hyperparameter_search.py
 src/transformers/image_processing_utils.py
 src/transformers/image_transforms.py
 src/transformers/image_utils.py
-src/transformers/integrations.py
+src/transformers/integrations/bitsandbytes.py
+src/transformers/integrations/deepspeed.py
+src/transformers/integrations/integration_utils.py
+src/transformers/integrations/peft.py
 src/transformers/keras_callbacks.py
-src/transformers/lib_integrations/peft/peft_mixin.py
 src/transformers/modelcard.py
 src/transformers/modeling_flax_outputs.py
 src/transformers/modeling_flax_pytorch_utils.py
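Below is a small, illustrative sketch (not part of the patch) of what the import surface looks like after this refactor, based only on the paths introduced above: the DeepSpeed glue code moves to `transformers.integrations.deepspeed`, the bitsandbytes helpers are re-exported from `transformers.integrations`, and the old `transformers.utils.bitsandbytes` module remains as a shim that emits a `FutureWarning` on import. Treat it as a usage note under those assumptions rather than as code from the pull request.

```python
# Illustrative only -- assumes a transformers build that contains the
# `integrations` subpackage laid out in the diff above.

# DeepSpeed glue code now lives in transformers.integrations.deepspeed
from transformers.integrations.deepspeed import HfDeepSpeedConfig, is_deepspeed_zero3_enabled

# bitsandbytes helpers are re-exported at the package level
from transformers.integrations import get_keys_to_not_convert, replace_with_bnb_linear

# Without a live DeepSpeed config object, ZeRO-3 is reported as disabled.
print(is_deepspeed_zero3_enabled())  # False

# The legacy path still imports, but emits a FutureWarning via the shim above.
from transformers.utils import bitsandbytes  # noqa: F401
```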