 from peft.tuners.lora.model import LoraModel
 import torch.distributed
 from transformers import AutoModelForCausalLM, TrainingArguments
+from transformers.modeling_utils import is_fsdp_enabled
 import torch
 import os
-
+import importlib
 
 class AutoGPTQAccelerationPlugin(AccelerationPlugin):
 
@@ -48,7 +49,6 @@ def __init__(self, configurations: Dict[str, Dict]):
         )
 
     def model_loader(self, model_name: str, **kwargs):
-
         # guarded imports
         # Third Party
         from auto_gptq import AutoGPTQForCausalLM, BaseQuantizeConfig
@@ -80,18 +80,6 @@ def model_loader(self, model_name: str, **kwargs):
         low_cpu_mem_usage = kwargs.get("low_cpu_mem_usage")
         attn_implementation = kwargs.get("attn_implementation")
 
-        if low_cpu_mem_usage:
-            # Note that low_cpu_mem_usage is typically set to transformers.modeling_utils.is_fsdp_enabled.
-            # e.g., https://github.com/huggingface/transformers/blob/a98c41798cf6ed99e1ff17e3792d6e06a2ff2ff3/src/transformers/modeling_utils.py#L2989-L2990
-            # but not doing that now as AutoGPTQ will call make_sure_no_tensor_in_meta_device
-            # https://github.com/AutoGPTQ/AutoGPTQ/blob/ea829c7bbe83561c2b1de26795b6592992373ef7/auto_gptq/modeling/_base.py#L982C17-L982C51
-            # which does not properly check if a QuantLayer has a bias set or not,
-            # https://github.com/AutoGPTQ/AutoGPTQ/blob/ea829c7bbe83561c2b1de26795b6592992373ef7/auto_gptq/modeling/_utils.py#L514
-            raise ValueError(
-                "low_cpu_mem_usage set to True. This may raise error if model has no bias, "
-                "due to AutoGPTQ bug. Not supporting at the moment."
-            )
-
         # there are some kwargs that won't be passed to AutoModel, so we need
         # to patch them in
         _old_from_config = AutoModelForCausalLM.from_config
@@ -103,12 +91,25 @@ def model_loader(self, model_name: str, **kwargs):
         )
         AutoModelForCausalLM.from_config = _from_config  # patch
 
+        if is_fsdp_enabled():
+            from .autogptq_utils import make_sure_no_tensor_in_meta_device
+            source = importlib.import_module("auto_gptq.modeling._utils")
+            original_obj = getattr(source, "make_sure_no_tensor_in_meta_device")
+            setattr(source, "make_sure_no_tensor_in_meta_device", make_sure_no_tensor_in_meta_device)
+            # reload so the module picks up the patched object
+            target_module = importlib.import_module("auto_gptq.modeling._base")
+            importlib.reload(target_module)
+            low_cpu_mem_usage = True
+
         # NOTE: need to set the device map as below as we want to use AutoGPTQ for training.
         # device_map is for inference only https://huggingface.co/docs/accelerate/en/concept_guides/big_model_inference
         # Thus we set it as below to effectively disable it.
-        device_map = (
-            {"": torch.cuda.current_device()} if torch.cuda.is_available() else None
-        )
+        device_map = {
+            "": (
+                torch.cuda.current_device() if not low_cpu_mem_usage
+                else "cpu"
+            ) if torch.cuda.is_available() else None
+        }
 
         # currently only enable triton_v2, because the triton kernels are the only ones
         # that have backwards
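
A minimal standalone sketch of the patch-and-reload step introduced above, for reference (assumes auto_gptq is installed; the wrapper name and the patched_fn argument are illustrative, not part of the commit):

import importlib

def patch_autogptq_meta_device_check(patched_fn):
    """Swap auto_gptq's make_sure_no_tensor_in_meta_device for patched_fn."""
    # rebind the symbol on the module that defines it
    utils_mod = importlib.import_module("auto_gptq.modeling._utils")
    setattr(utils_mod, "make_sure_no_tensor_in_meta_device", patched_fn)
    # auto_gptq.modeling._base appears to bind the helper by name at import
    # time, so reload it to re-execute its imports and pick up the patch
    base_mod = importlib.import_module("auto_gptq.modeling._base")
    importlib.reload(base_mod)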