
Commit 8103238: Add MLP & QLoRA Fused Ops and Kernels, Mixtral (#29)
* refactor
* fixes
* refactor mistral
* add mixtral
* some refactoring after introducing mlp
* remove extraneous files
* add bnb
* lint + fmt and improvements to readme
* bench fixes
* need to handle lora adapters device due to #26
* allow replay of failed benches, addressing comment in #14
* update benches (remove l40)

Signed-off-by: Yu Chin Fabian Lim <[email protected]>
1 parent b2b8fe6 commit 8103238

File tree: 23 files changed, +626 -326 lines

README.md
Lines changed: 1 addition & 2 deletions

@@ -31,7 +31,7 @@ Plugin | Description | Depends | License | Status
 --|--|--|--|--
 [framework](./plugins/framework/README.md) | This acceleration framework for integration with huggingface trainers | | | Beta
 [accelerated-peft](./plugins/accelerated-peft/README.md) | For PEFT-training, e.g., 4bit QLoRA. | Huggingface<br>AutoGPTQ | Apache 2.0<br>MIT | Beta
-[fused-op-and-kernels](./plugins/fused-ops-and-kernels/README.md) | Fused LoRA and triton kernels (e.g., fast cross-entropy, rms, rope) | -- | Apache 2.0 with exclusions. | Coming Soon
+[fused-op-and-kernels](./plugins/fused-ops-and-kernels/README.md) | Fused LoRA and triton kernels (e.g., fast cross-entropy, rms, rope) | -- | Apache 2.0 [(contains extracted code)](./plugins/fused-ops-and-kernels/README.md#code-extracted-from-unsloth)| Beta
 MOE-training-acceleration | [MegaBlocks](https://github.com/databricks/megablocks) inspired triton Kernels and acclerations for Mixture-of-Expert models | | Apache 2.0 | Coming Soon
 
 ## Usage with FMS HF Tuning

@@ -174,7 +174,6 @@ The benchmarks can be reproduced [with the provided scripts](./scripts/benchmark
 
 See below CSV files for various results:
 - [A100-80GB](./scripts/benchmarks/refs/a100_80gb.csv).
-- [L40-40GB](./scripts/benchmarks/refs/l40_40gb.csv).
 
 ### Code Architecture

plugins/accelerated-peft/src/fms_acceleration_peft/autogptq_utils.py
Lines changed: 4 additions & 2 deletions

@@ -28,6 +28,7 @@
 # consider making a map if patching more kernels
 PATCH_FOR_FSDP_TRITON_V2 = ["qweight", "qzeros"]
 
+
 # This function may be moved after merging
 # https://github.com/foundation-model-stack/fms-acceleration/pull/25
 def _patch_target_module(
@@ -123,6 +124,7 @@ def create_new_module_peft(
     # if module cannot be found, return None which results in a raise in the call-stack
     return new_module
 
+
 # consider to move this somewhere more general
 def patch_forward_to_view_attributes_before_call(
     old_forward: Callable,
@@ -133,9 +135,9 @@ def patch_forward_to_view_attributes_before_call(
 ):
     # patch old_forward to view attribtues to torch_dype
     # before call
-
+
     if submodule_names is None:
-        submodule_names = ''
+        submodule_names = ""
     if isinstance(submodule_names, str):
         submodule_names = [submodule_names]
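For readers unfamiliar with `patch_forward_to_view_attributes_before_call`, the hunks above only touch its formatting; the idea of the helper is to wrap a module's forward so that selected tensor attributes are viewed as a target dtype just before the call. Below is a hypothetical, simplified analogue of that idea; the name, signature, and the buffer assumption are ours, not the repo's.

```python
# Hypothetical sketch of the "view attributes before call" idea; not the
# repo's implementation. Assumes the named attributes (e.g. "qweight",
# "qzeros") are registered buffers, so a plain setattr swaps the buffer,
# and that old_forward is the module's original *bound* forward.
from functools import wraps
from typing import Callable, List

import torch


def view_attributes_before_call(
    old_forward: Callable,
    attribute_names: List[str],
    torch_dtype: torch.dtype,
) -> Callable:
    @wraps(old_forward)
    def _forward(self, *args, **kwargs):
        for name in attribute_names:
            # reinterpret the buffer's raw storage as torch_dtype
            setattr(self, name, getattr(self, name).view(torch_dtype))
        return old_forward(*args, **kwargs)

    # the returned function is meant to be bound onto the module,
    # e.g. with types.MethodType(_forward, module)
    return _forward
```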

plugins/accelerated-peft/src/fms_acceleration_peft/framework_plugin_autogptq.py
Lines changed: 18 additions & 7 deletions

@@ -51,13 +51,18 @@ def __init__(self, configurations: Dict[str, Dict]):
     def model_loader(self, model_name: str, **kwargs):
         # guarded imports
         # Third Party
-        from auto_gptq import AutoGPTQForCausalLM, BaseQuantizeConfig #pylint: disable=import-outside-toplevel,import-error
-        from auto_gptq.nn_modules.qlinear.qlinear_tritonv2 import QuantLinear #pylint: disable=import-outside-toplevel,import-error
+        from auto_gptq import ( # pylint: disable=import-outside-toplevel,import-error
+            AutoGPTQForCausalLM,
+            BaseQuantizeConfig,
+        )
+        from auto_gptq.nn_modules.qlinear.qlinear_tritonv2 import ( # pylint: disable=import-outside-toplevel,import-error
+            QuantLinear,
+        )
 
         # Local
-        from .autogptq_utils import ( #pylint: disable=import-outside-toplevel
-            patch_forward_to_view_attributes_before_call,
-            PATCH_FOR_FSDP_TRITON_V2
+        from .autogptq_utils import ( # pylint: disable=import-outside-toplevel
+            PATCH_FOR_FSDP_TRITON_V2,
+            patch_forward_to_view_attributes_before_call,
         )
 
         # Currently we allow only a quantized checkpoint to be loaded, we do not
@@ -214,8 +219,14 @@ def augmentation(
     ):
         # guarded imports
         # Third Party
-        from auto_gptq.nn_modules.qlinear.qlinear_tritonv2 import QuantLinear #pylint: disable=import-outside-toplevel,import-error
-        from auto_gptq.utils.peft_utils import GPTQLoraModel, get_gptq_peft_model #pylint: disable=import-outside-toplevel,import-error
+        from auto_gptq.nn_modules.qlinear.qlinear_tritonv2 import ( # pylint: disable=import-outside-toplevel,import-error
+            QuantLinear,
+        )
+        from auto_gptq.utils.peft_utils import ( # pylint: disable=import-outside-toplevel,import-error
+            GPTQLoraModel,
+            get_gptq_peft_model,
+        )
+
         # Local
         from .autogptq_utils import ( # pylint: disable=import-outside-toplevel
             create_new_module_peft,

plugins/fused-ops-and-kernels/README.md
Lines changed: 11 additions & 25 deletions

@@ -3,7 +3,7 @@
 This library contains fused operations and custom kernels, to be expanded over time. Currently it contains the following:
 
 
-1. Fused operations and kernels are extracted from [unsloth](#extracted-code-from-unsloth).
+1. Fused operations and kernels extracted from [unsloth](#extracted-code-from-unsloth).
     - Low-Rank Adapter Fused Operations
    - Fast RoPE Triton Kernels
    - Fast RMS LayerNorm Triton Kernels
@@ -13,42 +13,28 @@ This library contains fused operations and custom kernels, to be expanded over t
 
 Plugin | Description | Depends | Loading | Augmentation | Callbacks
 --|--|--|--|--|--
-[fast_quantized_peft](./src/fms_accelerate_foak/framework_plugin_fast_quantized_peft.py) | Loads fused lora, fast cross-entropy, fast rms, fast RoPE | | | ✅
+[fast_quantized_peft](./src/fms_accelerate_foak/framework_plugin_fast_quantized_peft.py) | LoRA fused ops, fast cross-entropy, fast rms, fast RoPE | Contains extracted code | | ✅
 
 ### Code Extracted from Unsloth
 
-<!--
-NOTE: the
-- fused_ops/unsloth_lora -> unsloth main
-    * utils (fast_dequant, fast_gemv, fast_linear_forward, matmul_lora)
-    * geglu, swiglu (this can be reused across other models, but currently used inside MLP fused ops only)
-    * bnb (fast_lora.py)
-    * gtqp (fast_lora, triton) -> jeromeku
-- kernels
-    * cross_ent, rms, rope -> unsloth main
--->
 
 Notes on the extraction of code from [unsloth](https://github.com/unslothai/unsloth):
-- while unsloth is released under Apache 2.0, there are [exceptions to the permissive licenses scattered in the code base](https://github.com/unslothai/unsloth/blob/ec19e61c854dcf9104386fa63fc6c4f2944d4f35/unsloth/models/llama.py#L1140-L1143).
+- While unsloth is [released under Apache 2.0](https://github.com/unslothai/unsloth/blob/main/LICENSE), there are comments indicating some exceptions strewn throughout the code base, see [an example here](https://github.com/unslothai/unsloth/blob/ec19e61c854dcf9104386fa63fc6c4f2944d4f35/unsloth/models/llama.py#L1140-L1143).
 ```
-it would require a commercial license if used to run on more than 4 GPUs, see
-https://github.com/unslothai/unsloth/blob/d215fd902cf28feb8abcfde2d25281d0fbf9d28c/unsloth/models/llama.py#L1140-L1143
+it would require a commercial license if used to run on more than 4 GPUs ...
 ```
-- these exceptions appear around [Feb 2024 Release](https://github.com/unslothai/unsloth/commit/3e4c5a323c16bbda2c92212b790073c4e99c2a55), around the model files (namely `llama.py`, `mistral.py`, etc).
-    * These model files are **not extracted**.
-- All code extracted here before the Feb 2024 Release, see table below.
+- These exceptions appear to be located around the trainer improvements, see [another example here](https://github.com/unslothai/unsloth/blob/ec19e61c854dcf9104386fa63fc6c4f2944d4f35/unsloth/models/llama.py#L1177-L1183).
+- These exceptions appear around the [Feb 2024 Release](https://github.com/unslothai/unsloth/commit/3e4c5a323c16bbda2c92212b790073c4e99c2a55); any code that appears in any file where such exceptions occur **is not extracted**.
+- In its place we adopt a different approach: we patch the model, as opposed to unsloth's approach of rewriting it. Our approach is novel and **completely rewritten from scratch**.
+- All extracted code appears before the Feb 2024 Release.
+- In the table below we record what was extracted, and the exact commit from which it was taken.
 
 Path | Description | Extracted From | Modifications | Date
 --|--|--|--|--
 [fused_ops/unsloth_lora](./src/fms_acceleration_foak/fused_ops/unsloth_lora) | QLoRA fast dequant, activation kernels | `unsloth/main` @ [1ecc0185](https://github.com/unslothai/unsloth/commit/1ecc0185a5759c7a0c95dfc96aceea5023cebdfc) | | 28 Jan 2024
-[fused_ops/unsloth_lora/bnb](./src/fms_acceleration_foak/fused_ops/unsloth_lora/bnb) | BNB fast lora | `unsloth/main` @ [1ecc0185](https://github.com/unslothai/unsloth/commit/1ecc0185a5759c7a0c95dfc96aceea5023cebdfc) | | 28 Jan 2024
+[fused_ops/unsloth_lora/bnb](./src/fms_acceleration_foak/fused_ops/unsloth_lora/bnb) | BNB fast lora | `unsloth/main` @ [1ecc0185](https://github.com/unslothai/unsloth/commit/1ecc0185a5759c7a0c95dfc96aceea5023cebdfc) | `fast_lora.py` | 28 Jan 2024
 [fused_ops/unsloth_lora/gptq](./src/fms_acceleration_foak/fused_ops/unsloth_lora/gptq) | GPTQ fast dequant (triton_v2) | `jeromeku/main` @ [2839d39](https://github.com/jeromeku/unsloth/commit/2839d390ef3bb318904289bfb9a7751a782c4e44) | `fast_lora.py`<br>`triton/layers.py` | 6 Feb 2024
-[kernels/unsloth](./src/fms_acceleration_foak/kernels/unsloth) | Fast RMS, RoPE, CrossEnt kernels | `unsloth/main` @ [1ecc0185](https://github.com/unslothai/unsloth/commit/1ecc0185a5759c7a0c95dfc96aceea5023cebdfc) | `cross_entropy_loss.py` | 28 Jan 2024
-
-<!--
-[models/](./src/fms_accelerate_unsloth/models/) | Model Forwards | `unsloth/main` @ [1ecc0185](https://github.com/unslothai/unsloth/commit/1ecc0185a5759c7a0c95dfc96aceea5023cebdfc)<br><br>`tohrnii/mixtral` @ [a55b7400](https://github.com/tohrnii/unsloth/commit/a55b740062b4fc8ce8f5196bfabe3cf860020ca7) | `llama.py`<br>`mistral.py`<br>`mixtral.py`| `llama.py`<br>`mistral.py`<br>`mixtral.py` | 6 Feb 2024<br><br> 22 Feb 2024
--->
-
+[kernels/unsloth](./src/fms_acceleration_foak/kernels/unsloth) | Fast RMS, RoPE, CrossEnt kernels | `unsloth/main` @ [1ecc0185](https://github.com/unslothai/unsloth/commit/1ecc0185a5759c7a0c95dfc96aceea5023cebdfc) | `cross_entropy_loss.py`<br>`rms_layernorm.py` | 28 Jan 2024
 
 ## Known Issues

plugins/fused-ops-and-kernels/src/fms_acceleration_foak/framework_plugin_fast_quantized_peft.py
Lines changed: 15 additions & 6 deletions

@@ -16,6 +16,7 @@
 from typing import Callable, Dict, Tuple
 
 # Third Party
+from accelerate.utils import set_module_tensor_to_device
 from fms_acceleration import AccelerationPlugin
 from peft import LoraConfig
 from peft.tuners.lora.layer import LoraLayer
@@ -63,9 +64,20 @@ def _all_reduce_hook(grad):
         return grad
 
     for mod in modules:
+        # NOTE: assuming lora has no bias
+        A = mod.lora_A.default
+        B = mod.lora_B.default
+
         # install hooks on the adapters
-        mod.lora_A.default.weight.register_hook(_all_reduce_hook)
-        mod.lora_B.default.weight.register_hook(_all_reduce_hook)
+        A.weight.register_hook(_all_reduce_hook)
+        B.weight.register_hook(_all_reduce_hook)
+
+        # because we will ignore these from FSDP, we need to manually
+        # move them to gpu if they are already not on them
+        if not A.weight.is_cuda:
+            set_module_tensor_to_device(A, "weight", "cuda")
+        if not B.weight.is_cuda:
+            set_module_tensor_to_device(B, "weight", "cuda")
 
 
 class FastQuantizedPeftAccelerationPlugin(AccelerationPlugin):
@@ -82,10 +94,7 @@ def __init__(self, configurations: Dict[str, Dict]):
 
         self._base_layer = self._check_config_and_maybe_check_values(
             key="peft.quantization.fused_ops_and_kernels.base_layer",
-            values=[
-                "auto_gptq",
-                # "bitsandbytes" # enable later when we have BNB implemented
-            ],
+            values=["auto_gptq", "bitsandbytes"],
         )
 
         # only support these at the moment
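For context on the second hunk above: the LoRA adapter weights are excluded from FSDP, so the plugin installs gradient hooks that all-reduce them across ranks and moves them onto the GPU itself. The sketch below shows that pattern in isolation; the hook body and the surrounding helper are assumptions (the diff does not show them), not the plugin's actual code.

```python
# Minimal sketch, assuming torch.distributed is already initialized and that
# the hook simply averages adapter gradients across data-parallel ranks.
import torch.distributed as dist
from accelerate.utils import set_module_tensor_to_device


def _all_reduce_hook(grad):
    # average the gradient across ranks (assumed behaviour of the hook)
    grad = grad / dist.get_world_size()
    dist.all_reduce(grad, op=dist.ReduceOp.SUM)
    return grad


def prepare_ignored_lora_adapters(modules):
    # `modules` is assumed to be an iterable of peft LoraLayer instances
    for mod in modules:
        A = mod.lora_A.default
        B = mod.lora_B.default

        # gradients of FSDP-ignored params must be synchronized manually
        A.weight.register_hook(_all_reduce_hook)
        B.weight.register_hook(_all_reduce_hook)

        # FSDP will not move ignored params, so place them on the GPU here
        if not A.weight.is_cuda:
            set_module_tensor_to_device(A, "weight", "cuda")
        if not B.weight.is_cuda:
            set_module_tensor_to_device(B, "weight", "cuda")
```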

plugins/fused-ops-and-kernels/src/fms_acceleration_foak/fused_ops/unsloth_lora/bnb/fast_lora.py
Lines changed: 7 additions & 0 deletions

@@ -394,3 +394,10 @@ def apply_lora_o(self, X):
     O = LoRA_W.apply(X, OW, OW_quant, OA, OB, OS)
     return O
 pass
+
+
+# this will be patchable on the actual module
+def apply_lora_o_v2(self, X):
+    OW, OW_quant, OA, OB, OS = get_lora_parameters(self)
+    O = LoRA_W.apply(X, OW, OW_quant, OA, OB, OS)
+    return O
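The comment on the added function says it "will be patchable on the actual module": unlike `apply_lora_o`, which reaches into `self.o_proj`, the `_v2` variant reads its LoRA parameters from `self` directly, so it can be bound straight onto the output-projection linear. A hypothetical illustration of such a patch follows; the helper name and the binding mechanism are assumptions, not code from this repo.

```python
# Hypothetical illustration only: bind apply_lora_o_v2 directly onto the
# o_proj submodule, so that calling o_proj(X) runs the fused LoRA op.
from types import MethodType

from fms_acceleration_foak.fused_ops.unsloth_lora.bnb.fast_lora import apply_lora_o_v2


def patch_o_proj_forward(attn_module):
    o_proj = attn_module.o_proj
    # apply_lora_o_v2(self, X) expects the linear itself as `self`
    o_proj.forward = MethodType(apply_lora_o_v2, o_proj)
    return attn_module
```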

plugins/fused-ops-and-kernels/src/fms_acceleration_foak/fused_ops/unsloth_lora/gptq/fast_lora.py
Lines changed: 7 additions & 0 deletions

@@ -735,3 +735,10 @@ def apply_lora_o(self, X):
     Oqstate, OA, OB, OS = get_lora_parameters(self.o_proj)
     O = LoRA_W.apply(X, *unpack_gptqstate(Oqstate), OA, OB, OS)
     return O
+
+
+# this version can be directly patched on the output linear
+def apply_lora_o_v2(self, X):
+    Oqstate, OA, OB, OS = get_lora_parameters(self)
+    O = LoRA_W.apply(X, *unpack_gptqstate(Oqstate), OA, OB, OS)
+    return O

plugins/fused-ops-and-kernels/src/fms_acceleration_foak/kernels/unsloth/rope_embedding.py
Lines changed: 3 additions & 2 deletions

@@ -130,8 +130,9 @@ def backward(ctx, dY):
     pass
 pass
 
-
-def fast_rope_embedding(Q, K, cos, sin):
+# modified by [email protected]
+# NOTE: fast_rope embeddings currently does not account for position ids
+def fast_rope_embedding(Q, K, cos, sin, position_ids=None):
     Q = Fast_RoPE_Embedding.apply(Q.transpose(1, 2), cos, sin).transpose(1, 2)
     K = Fast_RoPE_Embedding.apply(K.transpose(1, 2), cos, sin).transpose(1, 2)
     return Q, K
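The new `position_ids` parameter is accepted but ignored (per the NOTE), which keeps the signature compatible with callers that pass position ids. A hedged usage sketch is below; the tensor shapes are assumptions inferred from the transposes in the function body, not a documented contract.

```python
# Hedged usage sketch; shapes are assumptions, verify against the kernel.
import torch

from fms_acceleration_foak.kernels.unsloth.rope_embedding import fast_rope_embedding

# assumed: Q/K are (batch, num_heads, seq_len, head_dim); cos/sin are
# (seq_len, head_dim) rotary tables precomputed by the model
Q = torch.randn(2, 8, 128, 64, device="cuda", dtype=torch.float16)
K = torch.randn(2, 8, 128, 64, device="cuda", dtype=torch.float16)
cos = torch.randn(128, 64, device="cuda", dtype=torch.float16)
sin = torch.randn(128, 64, device="cuda", dtype=torch.float16)

# position_ids is accepted for signature compatibility but not used
Q_rot, K_rot = fast_rope_embedding(Q, K, cos, sin, position_ids=None)
```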

plugins/fused-ops-and-kernels/src/fms_acceleration_foak/models/__init__.py
Lines changed: 1 addition & 1 deletion

@@ -15,7 +15,7 @@
 # Local
 from .model_patcher import ModelPatcher
 
-PATCHES = [".models.llama", ".models.mistral"]
+PATCHES = [".models.llama", ".models.mistral", ".models.mixtral"]
 PLUGIN_PREFIX = "fms_acceleration_foak"
 
 # TODO: remove the need for the prefix

plugins/fused-ops-and-kernels/src/fms_acceleration_foak/models/llama.py
Lines changed: 54 additions & 8 deletions

@@ -16,14 +16,24 @@
 from functools import partial
 
 # Third Party
-from transformers.models.llama.modeling_llama import LlamaAttention, LlamaRMSNorm
+from transformers.models.llama.modeling_llama import (
+    LlamaAttention,
+    LlamaMLP,
+    LlamaRMSNorm,
+)
 
 # Local
 from ..kernels.unsloth.cross_entropy_loss import FastCrossEntropyLoss
 from ..kernels.unsloth.rms_layernorm import fast_rms_layernorm
 from ..kernels.unsloth.rope_embedding import fast_rope_embedding
-from .model_patcher import ModelPatcher, ModelPatcherRule, ModelPatcherTrigger
-from .utils import build_lora_fused_ops, trigger_fused_ops
+from .model_patcher import (
+    ModelPatcher,
+    ModelPatcherRule,
+    ModelPatcherTrigger,
+    combine_functions,
+    combine_triggers,
+)
+from .utils import KEY_MLP, KEY_O, KEY_QKV, build_lora_fused_ops, trigger_fused_ops
 
 # TODO: have a generic version of this rule
 # - do regex on RMSNorm class name
@@ -42,18 +52,54 @@
 ModelPatcher.register(
     ModelPatcherRule(
         rule_id="llama-qkvo",
+        trigger=combine_triggers(
+            ModelPatcherTrigger(
+                check=partial(
+                    trigger_fused_ops,
+                    attn_cls=LlamaAttention,
+                    submodule_names=["q_proj", "k_proj", "v_proj"],
+                )
+            ),
+            ModelPatcherTrigger(
+                check=partial(
+                    trigger_fused_ops,
+                    attn_cls=LlamaAttention,
+                    submodule_names=["o_proj"],
+                )
+            ),
+            logic="OR",
+        ),
+        forward_builder=combine_functions(
+            partial(
+                build_lora_fused_ops,
+                submodule_names=["q_proj", "k_proj", "v_proj"],
+                fused_op=KEY_QKV,
+            ),
+            partial(
+                build_lora_fused_ops,
+                submodule_names=["o_proj"],
+                fused_op=KEY_O,
+            ),
+            logic="APPEND",
+        ),
+        forward_builder_args=["base_type"],
+    )
+)
+
+ModelPatcher.register(
+    ModelPatcherRule(
+        rule_id="llama-mlp",
         trigger=ModelPatcherTrigger(
             check=partial(
                 trigger_fused_ops,
-                attn_cls=LlamaAttention,
-                qkv_module_names=["q_proj", "k_proj", "v_proj"],
-                o_module_name="o_proj",
+                attn_cls=LlamaMLP,
+                submodule_names=["up_proj", "down_proj", "gate_proj"],
             )
         ),
         forward_builder=partial(
             build_lora_fused_ops,
-            qkv_module_names=["q_proj", "k_proj", "v_proj"],
-            o_module_name="o_proj",
+            submodule_names=["up_proj", "down_proj", "gate_proj"],
+            fused_op=KEY_MLP,
         ),
         forward_builder_args=["base_type"],
     )
