
Commit e4e32b6

workaround low-mem patch
1 parent 79bc89b commit e4e32b6

2 files changed (+48, -17 lines)

plugins/accelerated-peft/src/fms_acceleration_peft/autogptq_utils.py

Lines changed: 30 additions & 0 deletions
@@ -22,6 +22,36 @@
 import torch
 
 
+def make_sure_no_tensor_in_meta_device(
+    model,
+    use_triton: bool,
+    desc_act: bool,
+    group_size: int,
+    bits: int,
+    disable_exllama: bool,
+    disable_exllamav2: bool,
+    use_marlin: bool = False,
+    use_tritonv2: bool = False,
+):
+    from auto_gptq.utils.import_utils import dynamically_import_QuantLinear #pylint: disable=import-outside-toplevel,import-error
+    QuantLinear = dynamically_import_QuantLinear(
+        use_triton,
+        desc_act,
+        group_size,
+        bits=bits,
+        disable_exllama=disable_exllama,
+        disable_exllamav2=disable_exllamav2,
+        use_marlin=use_marlin,
+        use_tritonv2=use_tritonv2
+    )
+    for n, m in model.named_modules():
+        bias = getattr(m, "bias", None)
+        if bias is not None:
+            if isinstance(m, QuantLinear) and bias.device == torch.device("meta"):
+                m.register_buffer(
+                    "bias", torch.zeros((m.outfeatures), dtype=torch.float16, device="cpu")
+                )
+
 def replace_module_peft(self, parent_module, child_name, new_module, old_module):
 
     # replace the lora linear

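Why the guard on bias matters: the upstream AutoGPTQ helper that this function replaces dereferences m.bias.device without first checking that the layer actually has a bias (see the comments removed from framework_plugin_autogptq.py below). The following is a minimal illustrative sketch, not the plugin's code; FakeQuantLinear is a hypothetical stand-in for AutoGPTQ's QuantLinear.

import torch
import torch.nn as nn


class FakeQuantLinear(nn.Module):
    """Toy stand-in for a quantized linear layer, possibly without a bias."""

    def __init__(self, outfeatures: int, with_bias: bool):
        super().__init__()
        self.outfeatures = outfeatures
        if with_bias:
            # mimic a layer materialized with low_cpu_mem_usage: buffers on "meta"
            self.register_buffer("bias", torch.empty(outfeatures, device="meta"))
        else:
            self.bias = None  # the case the upstream helper trips over


def fill_meta_bias(module: nn.Module) -> None:
    # guarded version: skip layers that have no bias at all, then replace a
    # meta-device bias with a real zero tensor on CPU
    bias = getattr(module, "bias", None)
    if bias is not None and bias.device == torch.device("meta"):
        module.register_buffer(
            "bias", torch.zeros(module.outfeatures, dtype=torch.float16, device="cpu")
        )
    # an unguarded `module.bias.device` lookup would raise AttributeError
    # whenever bias is None


for with_bias in (True, False):
    m = FakeQuantLinear(outfeatures=8, with_bias=with_bias)
    fill_meta_bias(m)  # safe in both cases
    print(with_bias, None if m.bias is None else m.bias.device)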
plugins/accelerated-peft/src/fms_acceleration_peft/framework_plugin_autogptq.py

Lines changed: 18 additions & 17 deletions
@@ -27,9 +27,10 @@
 from peft.tuners.lora.model import LoraModel
 import torch.distributed
 from transformers import AutoModelForCausalLM, TrainingArguments
+from transformers.modeling_utils import is_fsdp_enabled
 import torch
 import os
-
+import importlib
 
 class AutoGPTQAccelerationPlugin(AccelerationPlugin):
 
@@ -48,7 +49,6 @@ def __init__(self, configurations: Dict[str, Dict]):
         )
 
     def model_loader(self, model_name: str, **kwargs):
-
         # guarded imports
         # Third Party
         from auto_gptq import AutoGPTQForCausalLM, BaseQuantizeConfig
@@ -80,18 +80,6 @@ def model_loader(self, model_name: str, **kwargs):
         low_cpu_mem_usage = kwargs.get("low_cpu_mem_usage")
         attn_implementation = kwargs.get("attn_implementation")
 
-        if low_cpu_mem_usage:
-            # Note that low_cpu_mem_usage is typically set to transformers.modeling_utils.is_fsdp_enabled.
-            # e.g., https://github.com/huggingface/transformers/blob/a98c41798cf6ed99e1ff17e3792d6e06a2ff2ff3/src/transformers/modeling_utils.py#L2989-L2990
-            # but not doing that now as AutoGPTQ will call make_sure_no_tensor_in_meta_device
-            # https://github.com/AutoGPTQ/AutoGPTQ/blob/ea829c7bbe83561c2b1de26795b6592992373ef7/auto_gptq/modeling/_base.py#L982C17-L982C51
-            # which does not properly check if a QuantLayer has a bias set or not,
-            # https://github.com/AutoGPTQ/AutoGPTQ/blob/ea829c7bbe83561c2b1de26795b6592992373ef7/auto_gptq/modeling/_utils.py#L514
-            raise ValueError(
-                "low_cpu_mem_usage set to True. This may raise error if model has no bias, "
-                "due to AutoGPTQ bug. Not supporting at the moment."
-            )
-
         # there are some kwargs that we wont be passed to AutoModel, so we need
         # to patch them in
         _old_from_config = AutoModelForCausalLM.from_config
@@ -103,12 +91,25 @@ def model_loader(self, model_name: str, **kwargs):
            )
         AutoModelForCausalLM.from_config = _from_config # patch
 
+        if is_fsdp_enabled():
+            from .autogptq_utils import make_sure_no_tensor_in_meta_device
+            source = importlib.import_module("auto_gptq.modeling._utils")
+            original_obj = getattr(source, "make_sure_no_tensor_in_meta_device")
+            setattr(source, "make_sure_no_tensor_in_meta_device", make_sure_no_tensor_in_meta_device)
+            # reload and this should get the patched object
+            target_module = importlib.import_module("auto_gptq.modeling._base")
+            importlib.reload(target_module)
+            low_cpu_mem_usage = True
+
         # NOTE: need to set the device map as below as we want to use AutoGPTQ for training.
         # device_map is for inference only https://huggingface.co/docs/accelerate/en/concept_guides/big_model_inference
         # Thus we set it as below to effectively disable it.
-        device_map = (
-            {"": torch.cuda.current_device()} if torch.cuda.is_available() else None
-        )
+        device_map = {
+            "": (
+                torch.cuda.current_device() if not low_cpu_mem_usage
+                else "cpu"
+            ) if torch.cuda.is_available() else None
+        }
 
         # currently only enable triton_v2, because the triton kernels are the only ones
         # that have backwards

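On the reload of auto_gptq.modeling._base above: replacing the attribute on auto_gptq.modeling._utils alone would presumably not be enough if _base binds make_sure_no_tensor_in_meta_device by name at import time, so the dependent module is reloaded to pick up the patched object. Below is a self-contained sketch of that patch-and-reload pattern using two hypothetical throwaway modules, utils_mod and base_mod, written to a temporary directory; it does not touch the real auto_gptq packages.

import importlib
import sys
import tempfile
from pathlib import Path

# hypothetical modules standing in for auto_gptq.modeling._utils / _base
workdir = Path(tempfile.mkdtemp())
(workdir / "utils_mod.py").write_text(
    "def make_sure(model):\n"
    "    return 'original'\n"
)
(workdir / "base_mod.py").write_text(
    "from utils_mod import make_sure\n"
    "\n"
    "def load(model):\n"
    "    return make_sure(model)\n"
)
sys.path.insert(0, str(workdir))

utils_mod = importlib.import_module("utils_mod")
base_mod = importlib.import_module("base_mod")


def patched_make_sure(model):
    return "patched"


# Patching only the defining module is not enough: base_mod already holds its
# own reference to the original function from `from utils_mod import make_sure`.
setattr(utils_mod, "make_sure", patched_make_sure)
print(base_mod.load(None))  # -> original

# Reloading the dependent module re-executes its import and rebinds the name,
# so it now calls the patched function.
importlib.reload(base_mod)
print(base_mod.load(None))  # -> patched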