 import torch
 import torch.distributed
 
+
 class AutoGPTQAccelerationPlugin(AccelerationPlugin):
 
     require_packages = ["auto_gptq"]
@@ -50,11 +51,18 @@ def __init__(self, configurations: Dict[str, Dict]):
     def model_loader(self, model_name: str, **kwargs):
         # guarded imports
         # Third Party
-        from auto_gptq import AutoGPTQForCausalLM, BaseQuantizeConfig #pylint: disable=import-outside-toplevel,import-error
-        from auto_gptq.nn_modules.qlinear.qlinear_tritonv2 import QuantLinear #pylint: disable=import-outside-toplevel,import-error
+        from auto_gptq import (  # pylint: disable=import-outside-toplevel,import-error
+            AutoGPTQForCausalLM,
+            BaseQuantizeConfig,
+        )
+        from auto_gptq.nn_modules.qlinear.qlinear_tritonv2 import (  # pylint: disable=import-outside-toplevel,import-error
+            QuantLinear,
+        )
 
         # Local
-        from .autogptq_utils import patch_forward_to_view_attributes_before_call #pylint: disable=import-outside-toplevel
+        from .autogptq_utils import (  # pylint: disable=import-outside-toplevel
+            patch_forward_to_view_attributes_before_call,
+        )
 
         # Currently we allow only a quantized checkpoint to be loaded, we do not
         # implement the quantization process here.
@@ -93,27 +101,36 @@ def model_loader(self, model_name: str, **kwargs):
         AutoModelForCausalLM.from_config = _from_config  # patch
 
         if is_fsdp_enabled():
-            from .autogptq_utils import patch_target_module, make_sure_no_tensor_in_meta_device #pylint: disable=import-outside-toplevel
-            # We patch `make_sure_no_tensor_in_meta_device` from autogptq to avoid errors on models without bias
-            patch_target_module(
-                to_patch="auto_gptq.modeling._utils.make_sure_no_tensor_in_meta_device",
-                replace_with=make_sure_no_tensor_in_meta_device,
-                target_module="auto_gptq.modeling._base",
+            # Local
+            from .autogptq_utils import (  # pylint: disable=import-outside-toplevel
+                _patch_target_module,
+                make_sure_no_tensor_in_meta_device,
+            )
+
+            # We patch `make_sure_no_tensor_in_meta_device`
+            # from autogptq to avoid errors on models without bias
+            _patch_target_module(
+                to_patch="auto_gptq.modeling._utils.make_sure_no_tensor_in_meta_device",
+                replace_with=make_sure_no_tensor_in_meta_device,
+                target_module="auto_gptq.modeling._base",
             )
             low_cpu_mem_usage = True
 
         # NOTE: need to set the device map as below as we want to use AutoGPTQ for training.
-        # device_map is for inference only https://huggingface.co/docs/accelerate/en/concept_guides/big_model_inference
+        # device_map is for inference only
+        # https://huggingface.co/docs/accelerate/en/concept_guides/big_model_inference
         # For low_cpu_mem_usage = True, we have to set the device map to load checkpoints to "cpu"
         # to avoid gpu consumption before train
-        # This approach will divert consumption to cpu memory, a better approach would be to load the checkpoints to meta device
+        # This approach will divert consumption to cpu memory,
+        # a better approach would be to load the checkpoints to meta device
         # QLoRA is currently implemented by the former approach and will encounter the same issue.
         # see https://github.com/huggingface/transformers/pull/25107#issuecomment-2134833262
         device_map = {
             "": (
-                torch.cuda.current_device() if not low_cpu_mem_usage
-                else "cpu"
-            ) if torch.cuda.is_available() else None
+                (torch.cuda.current_device() if not low_cpu_mem_usage else "cpu")
+                if torch.cuda.is_available()
+                else None
+            )
         }
 
         # currently only enable triton_v2, because the triton kernels are the only ones
@@ -202,11 +219,19 @@ def augmentation(
     ):
         # guarded imports
         # Third Party
-        from auto_gptq.nn_modules.qlinear.qlinear_tritonv2 import QuantLinear #pylint: disable=import-outside-toplevel,import-error
-        from auto_gptq.utils.peft_utils import GPTQLoraModel, get_gptq_peft_model #pylint: disable=import-outside-toplevel,import-error
+        from auto_gptq.nn_modules.qlinear.qlinear_tritonv2 import (  # pylint: disable=import-outside-toplevel,import-error
+            QuantLinear,
+        )
+        from auto_gptq.utils.peft_utils import (  # pylint: disable=import-outside-toplevel,import-error
+            GPTQLoraModel,
+            get_gptq_peft_model,
+        )
 
         # Local
-        from .autogptq_utils import create_new_module_peft, replace_module_peft #pylint: disable=import-outside-toplevel
+        from .autogptq_utils import (  # pylint: disable=import-outside-toplevel
+            create_new_module_peft,
+            replace_module_peft,
+        )
 
         (peft_config,) = modifiable_args  # unpack modifiable args
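
Note: the hunks above rename `patch_target_module` to `_patch_target_module` and call it with `to_patch`, `replace_with`, and `target_module`. As a rough illustration of the module-patching pattern this relies on (the signature is taken from the call site in this diff; the actual helper in `.autogptq_utils` may be implemented differently), a minimal sketch could look like:

# Hypothetical sketch only, not the repository's actual _patch_target_module.
from typing import Optional
import importlib

def _patch_target_module(to_patch: str, replace_with, target_module: Optional[str] = None):
    # split "pkg.module.attr" into the defining module path and the attribute name
    module_path, attr = to_patch.rsplit(".", 1)
    defining_module = importlib.import_module(module_path)
    setattr(defining_module, attr, replace_with)  # patch at the definition site
    if target_module is not None:
        # also rebind the name in a module that did `from ... import attr`,
        # since that module holds its own reference to the original function
        consumer = importlib.import_module(target_module)
        setattr(consumer, attr, replace_with)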
212237
0 commit comments