104 changes: 59 additions & 45 deletions unsloth/models/_utils.py
@@ -76,6 +76,7 @@
import re
import warnings, subprocess, re, inspect, psutil, os, math
from unsloth_zoo.utils import Version
from unsloth import DEVICE_TYPE

from unsloth_zoo.tokenizer_utils import (
patch_tokenizer as _patch_tokenizer,
@@ -316,13 +317,20 @@ def patch_mistral_nemo_config(config):
# =============================================
# torch.cuda.amp.custom_fwd is deprecated >= 2.4
torch_version = torch.__version__
if Version(torch_version) < Version("2.4.0"):
torch_amp_custom_fwd = torch.cuda.amp.custom_fwd
torch_amp_custom_bwd = torch.cuda.amp.custom_bwd
else:
torch_amp_custom_fwd = torch.amp.custom_fwd(device_type = "cuda")
torch_amp_custom_bwd = torch.amp.custom_bwd(device_type = "cuda")
pass
if DEVICE_TYPE == "cuda":
if Version(torch_version) < Version("2.4.0"):
torch_amp_custom_fwd = torch.cuda.amp.custom_fwd
torch_amp_custom_bwd = torch.cuda.amp.custom_bwd
else:
torch_amp_custom_fwd = torch.amp.custom_fwd(device_type = "cuda")
torch_amp_custom_bwd = torch.amp.custom_bwd(device_type = "cuda")
pass
elif DEVICE_TYPE == "xpu":
if Version(torch_version) < Version("2.6.0"):
raise RuntimeError("torch.xpu support requires torch >= 2.6.0")
else:
torch_amp_custom_fwd = torch.amp.custom_fwd(device_type = "xpu")
torch_amp_custom_bwd = torch.amp.custom_bwd(device_type = "xpu")
# =============================================

# =============================================
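For reference, a minimal sketch (not taken from this PR) of how the decorators selected above are typically consumed: they wrap the forward/backward of a custom autograd Function so the op runs under autocast on whichever device was chosen. MyFusedSiLU is a hypothetical example name, not Unsloth code.

import torch

class MyFusedSiLU(torch.autograd.Function):
    @staticmethod
    @torch_amp_custom_fwd   # resolves to the cuda or xpu variant selected above
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return x * torch.sigmoid(x)

    @staticmethod
    @torch_amp_custom_bwd
    def backward(ctx, grad_output):
        (x,) = ctx.saved_tensors
        s = torch.sigmoid(x)
        # d/dx [x * sigmoid(x)] = sigmoid(x) + x * sigmoid(x) * (1 - sigmoid(x))
        return grad_output * (s + x * s * (1 - s))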
@@ -363,60 +371,66 @@ def _is_openai_available(): return False

# =============================================
# Get Flash Attention v2 if Ampere (RTX 30xx, A100)
import bitsandbytes as bnb
if DEVICE_TYPE == "cuda":
import bitsandbytes as bnb

from transformers import AutoTokenizer
from transformers.utils.import_utils import _is_package_available

major_version, minor_version = torch.cuda.get_device_capability()
SUPPORTS_BFLOAT16 = False
HAS_FLASH_ATTENTION = False
HAS_FLASH_ATTENTION_SOFTCAPPING = False

if major_version >= 8:
SUPPORTS_BFLOAT16 = True
if _is_package_available("flash_attn"):
# Check for CUDA linking errors "undefined symbol: _ZNK3c106SymIntltEl"
try:
if DEVICE_TYPE == "cuda":
major_version, minor_version = torch.cuda.get_device_capability()

if major_version >= 8:
SUPPORTS_BFLOAT16 = True
if _is_package_available("flash_attn"):
# Check for CUDA linking errors "undefined symbol: _ZNK3c106SymIntltEl"
try:
# See https://github.com/unslothai/unsloth/issues/1437
from flash_attn.flash_attn_interface import flash_attn_gpu
try:
# See https://github.com/unslothai/unsloth/issues/1437
from flash_attn.flash_attn_interface import flash_attn_gpu
except:
from flash_attn.flash_attn_interface import flash_attn_cuda
HAS_FLASH_ATTENTION = True

# Also check for softcapping
from flash_attn import __version__ as flash_attn_version
HAS_FLASH_ATTENTION_SOFTCAPPING = Version(flash_attn_version) >= Version("2.6.3")
if not HAS_FLASH_ATTENTION_SOFTCAPPING:
print(
"Unsloth: If you want to finetune Gemma 2, upgrade flash-attn to version 2.6.3 or higher!\n"\
"Newer versions support faster and less memory usage kernels for Gemma 2's attention softcapping!\n"\
"To update flash-attn, do the below:\n"\
'\npip install --no-deps --upgrade "flash-attn>=2.6.3"'
)
except:
from flash_attn.flash_attn_interface import flash_attn_cuda
HAS_FLASH_ATTENTION = True

# Also check for softcapping
from flash_attn import __version__ as flash_attn_version
HAS_FLASH_ATTENTION_SOFTCAPPING = Version(flash_attn_version) >= Version("2.6.3")
if not HAS_FLASH_ATTENTION_SOFTCAPPING:
print(
"Unsloth: If you want to finetune Gemma 2, upgrade flash-attn to version 2.6.3 or higher!\n"\
"Newer versions support faster and less memory usage kernels for Gemma 2's attention softcapping!\n"\
"To update flash-attn, do the below:\n"\
'\npip install --no-deps --upgrade "flash-attn>=2.6.3"'
"Unsloth: Your Flash Attention 2 installation seems to be broken?\n"\
"A possible explanation is you have a new CUDA version which isn't\n"\
"yet compatible with FA2? Please file a ticket to Unsloth or FA2.\n"\
"We shall now use Xformers instead, which does not have any performance hits!\n"\
"We found this negligible impact by benchmarking on 1x A100."
)
except:
print(
"Unsloth: Your Flash Attention 2 installation seems to be broken?\n"\
"A possible explanation is you have a new CUDA version which isn't\n"\
"yet compatible with FA2? Please file a ticket to Unsloth or FA2.\n"\
"We shall now use Xformers instead, which does not have any performance hits!\n"\
"We found this negligible impact by benchmarking on 1x A100."
)

# Stop Flash Attention from importing!
import transformers.utils.import_utils
transformers.utils.import_utils.is_flash_attn_2_available = lambda *args, **kwargs: False
import transformers.utils
transformers.utils.is_flash_attn_2_available = lambda *args, **kwargs: False
# Stop Flash Attention from importing!
import transformers.utils.import_utils
transformers.utils.import_utils.is_flash_attn_2_available = lambda *args, **kwargs: False
import transformers.utils
transformers.utils.is_flash_attn_2_available = lambda *args, **kwargs: False

HAS_FLASH_ATTENTION = False
pass
else:
HAS_FLASH_ATTENTION = False
pass
else:
# Tri Dao's benchmark shows xformers is faster for now.
HAS_FLASH_ATTENTION = False
else:
# Tri Dao's benchmark shows xformers is faster for now.
HAS_FLASH_ATTENTION = False
pass
pass
elif DEVICE_TYPE == "xpu":
SUPPORTS_BFLOAT16 = True

from transformers.models.llama.modeling_llama import logger
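Condensed, the feature flags this file now computes per device are roughly the following. This is a simplified sketch of the hunk above, with the flash-attn symbol probing, version check, and warning messages omitted.

SUPPORTS_BFLOAT16 = False
HAS_FLASH_ATTENTION = False

if DEVICE_TYPE == "cuda":
    major_version, _minor = torch.cuda.get_device_capability()
    if major_version >= 8:                    # Ampere (RTX 30xx, A100) or newer
        SUPPORTS_BFLOAT16 = True
        try:
            import flash_attn                 # broken installs fall back to xformers
            HAS_FLASH_ATTENTION = True
        except Exception:
            HAS_FLASH_ATTENTION = False
elif DEVICE_TYPE == "xpu":
    SUPPORTS_BFLOAT16 = True                  # XPU path assumes bf16; FA2 stays off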

75 changes: 47 additions & 28 deletions unsloth/models/llama.py
@@ -24,6 +24,8 @@
from transformers import __version__ as transformers_version
from unsloth_zoo.utils import Version, _get_dtype
from unsloth_zoo.peft_utils import SKIP_QUANTIZATION_MODULES
from unsloth import DEVICE_TYPE

transformers_version = Version(transformers_version)
# Transformers moved rotary embeddings out of all attention layers
IS_ATTENTION_REFACTOR = transformers_version > Version("4.47.1")
@@ -689,7 +691,7 @@ def LlamaModel_fast_forward(
position_ids = torch.arange(
past_key_values_length, seq_length + past_key_values_length,
dtype = torch.int32,
device = "cuda:0",
device = f"{DEVICE_TYPE}:0",
)
position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
elif position_ids is not None:
@@ -861,13 +863,13 @@ def LlamaModel_fast_forward(
is_causal = True,
sliding_window = self.config.sliding_window,
)\
.to_causal_4d(1, n, n, dtype = inputs_embeds.dtype, device = "cuda",)\
.to_causal_4d(1, n, n, dtype = inputs_embeds.dtype, device = DEVICE_TYPE,)\
.squeeze(0).squeeze(0)

self.GA_mask = AttentionMaskConverter(
is_causal = True,
)\
.to_causal_4d(1, n, n, dtype = inputs_embeds.dtype, device = "cuda",)\
.to_causal_4d(1, n, n, dtype = inputs_embeds.dtype, device = DEVICE_TYPE,)\
.squeeze(0).squeeze(0)
pass
pass
@@ -984,11 +986,11 @@ def LlamaModel_fast_forward_inference_custom(
bsz, q_len, hd = X.shape
assert(q_len == 1)
# Get saved buffers to reduce memory movement
residual = torch.empty((bsz, q_len, hd), dtype = torch.float32, device = "cuda:0")
_XX = torch.empty((2, bsz, q_len, hd), dtype = torch.float32, device = "cuda:0")
residual = torch.empty((bsz, q_len, hd), dtype = torch.float32, device = f"{DEVICE_TYPE}:0")
_XX = torch.empty((2, bsz, q_len, hd), dtype = torch.float32, device = f"{DEVICE_TYPE}:0")
XX, XX2 = _XX[0], _XX[1]
variance = torch.empty((bsz, q_len, 1), dtype = torch.float32, device = "cuda:0")
temp_mlp = torch.empty((2, bsz, 1, mlp_size), dtype = X.dtype, device = "cuda:0")
variance = torch.empty((bsz, q_len, 1), dtype = torch.float32, device = f"{DEVICE_TYPE}:0")
temp_mlp = torch.empty((2, bsz, 1, mlp_size), dtype = X.dtype, device = f"{DEVICE_TYPE}:0")
temp_gate, temp_up = temp_mlp[0], temp_mlp[1]

seq_len = past_key_values[0][0].shape[-2]
@@ -1305,7 +1307,7 @@ def __init__(self, dim = None, max_position_embeddings=2048, base=10000, device=
partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0
dim = getattr(config, "head_dim", None)
if dim is None: dim = int((config.hidden_size // config.num_attention_heads))
device = "cuda"
device = DEVICE_TYPE
max_position_embeddings = config.max_position_embeddings
pass

@@ -1354,7 +1356,7 @@ def extend_rope_embedding(self, x, seq_len):
if seq_len <= self.current_rope_size: return
# Iteratively grow by increments of 8192
self.current_rope_size = ((seq_len // 8192) + ((seq_len % 8192) != 0)) * 8192
self._set_cos_sin_cache(self.current_rope_size, device = "cuda", dtype = x.dtype)
self._set_cos_sin_cache(self.current_rope_size, device = DEVICE_TYPE, dtype = x.dtype)
pass
pass
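The cache-growth line in this hunk rounds the requested sequence length up to the next multiple of 8192 before rebuilding the RoPE cos/sin cache, for example:

seq_len = 9000
current_rope_size = ((seq_len // 8192) + ((seq_len % 8192) != 0)) * 8192
assert current_rope_size == 16384   # 9000 rounds up to 2 * 8192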

@@ -1400,7 +1402,7 @@ def __init__(self, dim = None, max_position_embeddings=2048, base=10000, device=
base = config.rope_theta
partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0
dim = int((config.hidden_size // config.num_attention_heads))
device = "cuda"
device = DEVICE_TYPE
max_position_embeddings = config.max_position_embeddings
pass

@@ -1480,7 +1482,7 @@ def extend_rope_embedding(self, x, seq_len):
if seq_len <= self.current_rope_size: return
# Iteratively grow by increments of 8192
self.current_rope_size = ((seq_len // 8192) + ((seq_len % 8192) != 0)) * 8192
self._set_cos_sin_cache(self.current_rope_size, device = "cuda", dtype = x.dtype)
self._set_cos_sin_cache(self.current_rope_size, device = DEVICE_TYPE, dtype = x.dtype)
pass
pass

@@ -1507,7 +1509,7 @@ def __init__(self,
base = config.rope_theta
partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0
dim = int((config.hidden_size // config.num_attention_heads))
device = "cuda"
device = DEVICE_TYPE
max_position_embeddings = config.max_position_embeddings
pass

@@ -1595,7 +1597,7 @@ def extend_rope_embedding(self, x, seq_len):
if seq_len <= self.current_rope_size: return
# Iteratively grow by increments of 8192
self.current_rope_size = ((seq_len // 8192) + ((seq_len % 8192) != 0)) * 8192
self._set_cos_sin_cache(self.current_rope_size, device = "cuda", dtype = x.dtype)
self._set_cos_sin_cache(self.current_rope_size, device = DEVICE_TYPE, dtype = x.dtype)
pass
pass

@@ -1643,7 +1645,7 @@ def unsloth_fast_generate(
kwargs["pad_token_id"] = kwargs.pop("pad_token_id", model_eos_token_id)

# Mixed precision autocast
with torch.inference_mode(), torch.autocast(device_type = "cuda", dtype = dtype):
with torch.inference_mode(), torch.autocast(device_type = DEVICE_TYPE, dtype = dtype):
output = self._old_generate(*args, **kwargs)
pass

@@ -1744,19 +1746,36 @@ def from_pretrained(
if token is None: token = get_token()
if model_patcher is None: model_patcher = FastLlamaModel
SUPPORTS_BFLOAT16 = is_bfloat16_supported()
gpu_stats = torch.cuda.get_device_properties(0)
max_memory = round(gpu_stats.total_memory / 1024 / 1024 / 1024, 3)

from importlib.metadata import version as importlib_version
try: vllm_version = f" vLLM: {importlib_version('vllm')}."
except: vllm_version = ""
if DEVICE_TYPE == "cuda":
gpu_stats = torch.cuda.get_device_properties(0)
gpu_version = torch.version.cuda
gpu_stats_snippet = f"CUDA: {gpu_stats.major}.{gpu_stats.minor}. CUDA Toolkit: {gpu_version}."
num_gpus = torch.cuda.device_count()

from importlib.metadata import version as importlib_version
try: vllm_version = f" vLLM: {importlib_version('vllm')}."
except: vllm_version = ""
elif DEVICE_TYPE == "xpu":
gpu_stats = torch.xpu.get_device_properties(0)
gpu_version = torch.version.xpu
num_gpus = torch.xpu.device_count()
gpu_stats_snippet = f"Intel Toolkit: {gpu_version}."

# TODO: After adding vLLM support for XPU, change this
vllm_version = ""
else:
raise ValueError(f"Unsloth: Unsupported device type: {DEVICE_TYPE}")

max_memory = round(gpu_stats.total_memory / 1024 / 1024 / 1024, 3)

statistics = \
f"==((====))== Unsloth {__version__}: Fast {model_patcher.__name__[4:-5]} patching. Transformers: {transformers_version}.{vllm_version}\n"\
f" {chr(92)}{chr(92)} /| {gpu_stats.name}. Num GPUs = {torch.cuda.device_count()}. Max memory: {max_memory} GB. Platform: {platform_system}.\n"\
f"O^O/ {chr(92)}_/ {chr(92)} Torch: {torch.__version__}. CUDA: {gpu_stats.major}.{gpu_stats.minor}. CUDA Toolkit: {torch.version.cuda}. Triton: {triton_version}\n"\
f"{chr(92)} / Bfloat16 = {str(SUPPORTS_BFLOAT16).upper()}. FA [Xformers = {xformers_version}. FA2 = {HAS_FLASH_ATTENTION}]\n"\
f' "-____-" Free license: http://github.com/unslothai/unsloth'
f"==((====))== Unsloth {__version__}: Fast {model_patcher.__name__[4:-5]} patching. Transformers: {transformers_version}.{vllm_version}\n"\
f" {chr(92)}{chr(92)} /| {gpu_stats.name}. Num GPUs = {num_gpus}. Max memory: {max_memory} GB. Platform: {platform_system}.\n"\
f"O^O/ {chr(92)}_/ {chr(92)} Torch: {torch.__version__}. {gpu_stats_snippet} Triton: {triton_version}\n"\
f"{chr(92)} / Bfloat16 = {str(SUPPORTS_BFLOAT16).upper()}. FA [Xformers = {xformers_version}. FA2 = {HAS_FLASH_ATTENTION}]\n"\
f' "-____-" Free license: http://github.com/unslothai/unsloth'

print(statistics)

# Warn about fast transfers
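Because the old and new banner strings are interleaved above, here is a condensed sketch of the device-stats dispatch that feeds the new banner (simplified; the vLLM version lookup and the ASCII art are left out):

if DEVICE_TYPE == "cuda":
    gpu_stats = torch.cuda.get_device_properties(0)
    num_gpus = torch.cuda.device_count()
    gpu_stats_snippet = f"CUDA: {gpu_stats.major}.{gpu_stats.minor}. CUDA Toolkit: {torch.version.cuda}."
elif DEVICE_TYPE == "xpu":
    gpu_stats = torch.xpu.get_device_properties(0)
    num_gpus = torch.xpu.device_count()
    gpu_stats_snippet = f"Intel Toolkit: {torch.version.xpu}."

max_memory = round(gpu_stats.total_memory / 1024 / 1024 / 1024, 3)   # GiB
print(f"{gpu_stats.name}. Num GPUs = {num_gpus}. Max memory: {max_memory} GB. {gpu_stats_snippet}")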
@@ -2215,7 +2234,7 @@ def get_peft_model(
pass

model.get_input_embeddings().modules_to_save.default\
.to(device = "cuda", dtype = new_dtype, non_blocking = True)
.to(device = DEVICE_TYPE, dtype = new_dtype, non_blocking = True)
model.get_input_embeddings().modules_to_save.default.requires_grad_(True)

# [TODO] Move old embed_tokens to CPU - should be disk!
@@ -2235,7 +2254,7 @@ def get_peft_model(
pass

model.get_output_embeddings().modules_to_save.default\
.to(device = "cuda", dtype = new_dtype, non_blocking = True)
.to(device = DEVICE_TYPE, dtype = new_dtype, non_blocking = True)
model.get_output_embeddings().modules_to_save.default.requires_grad_(True)

# [TODO] Move old lm_head to CPU - should be disk!
@@ -2499,7 +2518,7 @@ def get_peft_model(
pass

model.get_input_embeddings().modules_to_save.default\
.to(device = "cuda", dtype = new_dtype, non_blocking = True)
.to(device = DEVICE_TYPE, dtype = new_dtype, non_blocking = True)
model.get_input_embeddings().modules_to_save.default.requires_grad_(True)
pass

@@ -2515,7 +2534,7 @@ def get_peft_model(
pass

model.get_output_embeddings().modules_to_save.default\
.to(device = "cuda", dtype = new_dtype, non_blocking = True)
.to(device = DEVICE_TYPE, dtype = new_dtype, non_blocking = True)
model.get_output_embeddings().modules_to_save.default.requires_grad_(True)
pass
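Both files import DEVICE_TYPE from unsloth, but its definition is outside this changeset. A plausible sketch of how such a constant could be derived is shown below; this is an assumption for illustration, not the package's actual code.

import torch

def _infer_device_type() -> str:
    # Hypothetical helper: prefer CUDA, fall back to Intel XPU when available.
    if torch.cuda.is_available():
        return "cuda"
    if hasattr(torch, "xpu") and torch.xpu.is_available():
        return "xpu"
    raise NotImplementedError("Unsloth currently expects a CUDA or Intel XPU device.")

DEVICE_TYPE = _infer_device_type()   # the real definition lives in the unsloth package, not shown here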
