
Commit 6764755

resolve conflicts and define patch function
1 parent b6369c0 commit 6764755

File tree

2 files changed: +61 additions, -26 deletions

plugins/accelerated-peft/src/fms_acceleration_peft/autogptq_utils.py

Lines changed: 19 additions & 9 deletions
@@ -16,20 +16,23 @@
 # https://spdx.dev/learn/handling-license-info/

 # Standard
-from typing import Callable, List, Any
+from typing import Any, Callable, List
 import importlib

 # Third Party
 from peft import LoraConfig
 from peft.tuners.lora.gptq import QuantLinear as LoraLinearGPTQ
 import torch

-def patch_target_module(
+
+# This function will be replaced after merging
+# https://github.com/foundation-model-stack/fms-acceleration/pull/25
+def _patch_target_module(
     to_patch: str,
     replace_with: Any,
     target_module: str = None,
 ):
-    to_patch = to_patch.split('.')
+    to_patch = to_patch.split(".")
     assert len(to_patch) > 1, "must have an object to patch"

     to_patch, obj_name_to_patch = to_patch[:-1], to_patch[-1]
@@ -46,6 +49,7 @@ def patch_target_module(
     # replace it
     setattr(source, obj_name_to_patch, original_obj)

+
 def make_sure_no_tensor_in_meta_device(
     model,
     use_triton: bool,
@@ -57,7 +61,11 @@ def make_sure_no_tensor_in_meta_device(
     use_marlin: bool = False,
     use_tritonv2: bool = False,
 ):
-    from auto_gptq.utils.import_utils import dynamically_import_QuantLinear #pylint: disable=import-outside-toplevel,import-error
+    # Third Party
+    from auto_gptq.utils.import_utils import ( # pylint: disable=import-outside-toplevel,import-error
+        dynamically_import_QuantLinear,
+    )
+
     QuantLinear = dynamically_import_QuantLinear(
         use_triton,
         desc_act,
@@ -66,15 +74,17 @@ def make_sure_no_tensor_in_meta_device(
         disable_exllama=disable_exllama,
         disable_exllamav2=disable_exllamav2,
         use_marlin=use_marlin,
-        use_tritonv2=use_tritonv2
-    )
-    for n, m in model.named_modules():
+        use_tritonv2=use_tritonv2,
+    )
+    for _, m in model.named_modules():
         bias = getattr(m, "bias", None)
         if bias:
             if isinstance(m, QuantLinear) and bias.device == torch.device("meta"):
                 m.register_buffer(
-                    "bias", torch.zeros((m.outfeatures), dtype=torch.float16, device="cpu")
-                )
+                    "bias",
+                    torch.zeros((m.outfeatures), dtype=torch.float16, device="cpu"),
+                )
+

 def replace_module_peft(self, parent_module, child_name, new_module, old_module):
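Note: `_patch_target_module` works by importing the module that owns the target attribute and overwriting that attribute in place, so downstream code picks up the replacement. A minimal, self-contained sketch of the same idea (the `patch_attribute` helper and the `math.sqrt` toy target are illustrative assumptions, not this repo's exact implementation):

# Minimal sketch of importlib-based attribute patching (illustrative only;
# patch_attribute and the math.sqrt example are hypothetical, not repo code).
import importlib
from typing import Any


def patch_attribute(to_patch: str, replace_with: Any) -> Any:
    # "pkg.module.attr" -> import "pkg.module", then overwrite "attr"
    parts = to_patch.split(".")
    assert len(parts) > 1, "must have an object to patch"
    module_path, attr_name = ".".join(parts[:-1]), parts[-1]
    module = importlib.import_module(module_path)
    original = getattr(module, attr_name)
    setattr(module, attr_name, replace_with)
    return original  # caller can restore it later


# Toy usage: replace math.sqrt, observe the change, then restore it.
import math

original_sqrt = patch_attribute("math.sqrt", lambda x: -1.0)
assert math.sqrt(4) == -1.0
math.sqrt = original_sqrt
assert math.sqrt(4) == 2.0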

plugins/accelerated-peft/src/fms_acceleration_peft/framework_plugin_autogptq.py

Lines changed: 42 additions & 17 deletions
@@ -31,6 +31,7 @@
 import torch
 import torch.distributed

+
 class AutoGPTQAccelerationPlugin(AccelerationPlugin):

     require_packages = ["auto_gptq"]
@@ -50,11 +51,18 @@ def __init__(self, configurations: Dict[str, Dict]):
     def model_loader(self, model_name: str, **kwargs):
         # guarded imports
         # Third Party
-        from auto_gptq import AutoGPTQForCausalLM, BaseQuantizeConfig #pylint: disable=import-outside-toplevel,import-error
-        from auto_gptq.nn_modules.qlinear.qlinear_tritonv2 import QuantLinear #pylint: disable=import-outside-toplevel,import-error
+        from auto_gptq import ( # pylint: disable=import-outside-toplevel,import-error
+            AutoGPTQForCausalLM,
+            BaseQuantizeConfig,
+        )
+        from auto_gptq.nn_modules.qlinear.qlinear_tritonv2 import ( # pylint: disable=import-outside-toplevel,import-error
+            QuantLinear,
+        )

         # Local
-        from .autogptq_utils import patch_forward_to_view_attributes_before_call #pylint: disable=import-outside-toplevel
+        from .autogptq_utils import ( # pylint: disable=import-outside-toplevel
+            patch_forward_to_view_attributes_before_call,
+        )

         # Currently we allow only a quantized checkpoint to be loaded, we do not
         # implement the quantization process here.
@@ -93,27 +101,36 @@ def model_loader(self, model_name: str, **kwargs):
         AutoModelForCausalLM.from_config = _from_config # patch

         if is_fsdp_enabled():
-            from .autogptq_utils import patch_target_module, make_sure_no_tensor_in_meta_device #pylint: disable=import-outside-toplevel
-            # We patch `make_sure_no_tensor_in_meta_device` from autogptq to avoid errors on models without bias
-            patch_target_module(
-                to_patch = "auto_gptq.modeling._utils.make_sure_no_tensor_in_meta_device",
-                replace_with = make_sure_no_tensor_in_meta_device,
-                target_module = "auto_gptq.modeling._base",
+            # Local
+            from .autogptq_utils import ( # pylint: disable=import-outside-toplevel
+                _patch_target_module,
+                make_sure_no_tensor_in_meta_device,
+            )
+
+            # We patch `make_sure_no_tensor_in_meta_device`
+            # from autogptq to avoid errors on models without bias
+            _patch_target_module(
+                to_patch="auto_gptq.modeling._utils.make_sure_no_tensor_in_meta_device",
+                replace_with=make_sure_no_tensor_in_meta_device,
+                target_module="auto_gptq.modeling._base",
             )
             low_cpu_mem_usage = True

         # NOTE: need to set the device map as below as we want to use AutoGPTQ for training.
-        # device_map is for inference only https://huggingface.co/docs/accelerate/en/concept_guides/big_model_inference
+        # device_map is for inference only
+        # https://huggingface.co/docs/accelerate/en/concept_guides/big_model_inference
         # For low_cpu_mem_usage = True, we have to set the device map to load checkpoints to "cpu"
         # to avoid gpu consumption before train
-        # This approach will divert consumption to cpu memory, a better approach would be to load the checkpoints to meta device
+        # This approach will divert consumption to cpu memory,
+        # a better approach would be to load the checkpoints to meta device
         # QLoRA is currently implemented by the former approach and will encounter the same issue.
         # see https://github.com/huggingface/transformers/pull/25107#issuecomment-2134833262
         device_map = {
             "": (
-                torch.cuda.current_device() if not low_cpu_mem_usage
-                else "cpu"
-            ) if torch.cuda.is_available() else None
+                (torch.cuda.current_device() if not low_cpu_mem_usage else "cpu")
+                if torch.cuda.is_available()
+                else None
+            )
         }

         # currently only enable triton_v2, because the triton kernels are the only ones
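Note: the NOTE comments in the hunk above explain the device-map choice. The sketch below restates that selection as a standalone helper so the three outcomes are explicit; `select_device_map` is a hypothetical name, since the plugin builds this dict inline inside model_loader:

# Sketch of the device_map selection above (illustrative; not a helper that
# exists in the plugin).
import torch


def select_device_map(low_cpu_mem_usage: bool) -> dict:
    return {
        "": (
            (torch.cuda.current_device() if not low_cpu_mem_usage else "cpu")
            if torch.cuda.is_available()
            else None
        )
    }


# With CUDA and low_cpu_mem_usage=False -> {"": 0}      (current device index)
# With CUDA and low_cpu_mem_usage=True  -> {"": "cpu"}  (defer GPU placement until training)
# Without CUDA                          -> {"": None}
print(select_device_map(low_cpu_mem_usage=True))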
@@ -202,11 +219,19 @@ def augmentation(
     ):
         # guarded imports
         # Third Party
-        from auto_gptq.nn_modules.qlinear.qlinear_tritonv2 import QuantLinear #pylint: disable=import-outside-toplevel,import-error
-        from auto_gptq.utils.peft_utils import GPTQLoraModel, get_gptq_peft_model #pylint: disable=import-outside-toplevel,import-error
+        from auto_gptq.nn_modules.qlinear.qlinear_tritonv2 import ( # pylint: disable=import-outside-toplevel,import-error
+            QuantLinear,
+        )
+        from auto_gptq.utils.peft_utils import ( # pylint: disable=import-outside-toplevel,import-error
+            GPTQLoraModel,
+            get_gptq_peft_model,
+        )

         # Local
-        from .autogptq_utils import create_new_module_peft, replace_module_peft #pylint: disable=import-outside-toplevel
+        from .autogptq_utils import ( # pylint: disable=import-outside-toplevel
+            create_new_module_peft,
+            replace_module_peft,
+        )

         (peft_config,) = modifiable_args # unpack modifiable args
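Note: like model_loader, augmentation keeps its auto_gptq imports inside the method (the "guarded imports" marked with pylint disables), so the plugin module stays importable even when auto_gptq is not installed. A hedged sketch of that pattern, using a hypothetical helper rather than the plugin's actual code:

# Sketch of the guarded-import pattern (illustrative; load_gptq_entrypoints is
# a hypothetical helper, not part of this repo).
def load_gptq_entrypoints():
    try:
        # Third Party, imported lazily so module import never requires auto_gptq
        # pylint: disable=import-outside-toplevel,import-error
        from auto_gptq import AutoGPTQForCausalLM, BaseQuantizeConfig
    except ImportError as exc:
        raise RuntimeError(
            "auto_gptq must be installed to use the AutoGPTQ acceleration plugin"
        ) from exc
    return AutoGPTQForCausalLM, BaseQuantizeConfig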
