diff --git a/unsloth/models/_utils.py b/unsloth/models/_utils.py
index 84dd72e88..0658f3f8f 100644
--- a/unsloth/models/_utils.py
+++ b/unsloth/models/_utils.py
@@ -76,6 +76,7 @@ import re
 import warnings, subprocess, re, inspect, psutil, os, math
 from unsloth_zoo.utils import Version
+from unsloth import DEVICE_TYPE
 
 from unsloth_zoo.tokenizer_utils import (
     patch_tokenizer as _patch_tokenizer,
@@ -316,13 +317,20 @@ def patch_mistral_nemo_config(config):
 # =============================================
 # torch.cuda.amp.custom_fwd is deprecated >= 2.4
 torch_version = torch.__version__
-if Version(torch_version) < Version("2.4.0"):
-    torch_amp_custom_fwd = torch.cuda.amp.custom_fwd
-    torch_amp_custom_bwd = torch.cuda.amp.custom_bwd
-else:
-    torch_amp_custom_fwd = torch.amp.custom_fwd(device_type = "cuda")
-    torch_amp_custom_bwd = torch.amp.custom_bwd(device_type = "cuda")
-pass
+if DEVICE_TYPE == "cuda":
+    if Version(torch_version) < Version("2.4.0"):
+        torch_amp_custom_fwd = torch.cuda.amp.custom_fwd
+        torch_amp_custom_bwd = torch.cuda.amp.custom_bwd
+    else:
+        torch_amp_custom_fwd = torch.amp.custom_fwd(device_type = "cuda")
+        torch_amp_custom_bwd = torch.amp.custom_bwd(device_type = "cuda")
+    pass
+elif DEVICE_TYPE == "xpu":
+    if Version(torch_version) < Version("2.6.0"):
+        raise RuntimeError("torch.xpu currently only supports torch.version >= 2.6.0")
+    else:
+        torch_amp_custom_fwd = torch.amp.custom_fwd(device_type = "xpu")
+        torch_amp_custom_bwd = torch.amp.custom_bwd(device_type = "xpu")
 # =============================================
 
 # =============================================
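Note (not part of the diff): the hunk above now selects the AMP decorators per device type. A minimal sketch of how such decorators are typically attached to a custom autograd Function, assuming torch >= 2.4 and `DEVICE_TYPE` in {"cuda", "xpu"}; the `Square` function below is illustrative only and not Unsloth code:

```python
import torch

DEVICE_TYPE = "cuda"  # stand-in for the constant imported from unsloth

# Same pattern as the hunk above: torch.amp.custom_fwd / custom_bwd called with a
# device_type return decorators that mark a Function's forward/backward for
# mixed-precision autocast on that device.
torch_amp_custom_fwd = torch.amp.custom_fwd(device_type = DEVICE_TYPE)
torch_amp_custom_bwd = torch.amp.custom_bwd(device_type = DEVICE_TYPE)

class Square(torch.autograd.Function):
    @staticmethod
    @torch_amp_custom_fwd
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return x * x

    @staticmethod
    @torch_amp_custom_bwd
    def backward(ctx, grad_output):
        (x,) = ctx.saved_tensors
        return 2.0 * x * grad_output
```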
@@ -363,60 +371,66 @@ def _is_openai_available(): return False
 # =============================================
 # Get Flash Attention v2 if Ampere (RTX 30xx, A100)
-import bitsandbytes as bnb
+if DEVICE_TYPE == "cuda":
+    import bitsandbytes as bnb
+
 from transformers import AutoTokenizer
 from transformers.utils.import_utils import _is_package_available
-major_version, minor_version = torch.cuda.get_device_capability()
 SUPPORTS_BFLOAT16 = False
 HAS_FLASH_ATTENTION = False
 HAS_FLASH_ATTENTION_SOFTCAPPING = False
-if major_version >= 8:
-    SUPPORTS_BFLOAT16 = True
-    if _is_package_available("flash_attn"):
-        # Check for CUDA linking errors "undefined symbol: _ZNK3c106SymIntltEl"
-        try:
+if DEVICE_TYPE == "cuda":
+    major_version, minor_version = torch.cuda.get_device_capability()
+
+    if major_version >= 8:
+        SUPPORTS_BFLOAT16 = True
+        if _is_package_available("flash_attn"):
+            # Check for CUDA linking errors "undefined symbol: _ZNK3c106SymIntltEl"
             try:
-                # See https://github.com/unslothai/unsloth/issues/1437
-                from flash_attn.flash_attn_interface import flash_attn_gpu
+                try:
+                    # See https://github.com/unslothai/unsloth/issues/1437
+                    from flash_attn.flash_attn_interface import flash_attn_gpu
+                except:
+                    from flash_attn.flash_attn_interface import flash_attn_cuda
+                HAS_FLASH_ATTENTION = True
+
+                # Also check for softcapping
+                from flash_attn import __version__ as flash_attn_version
+                HAS_FLASH_ATTENTION_SOFTCAPPING = Version(flash_attn_version) >= Version("2.6.3")
+                if not HAS_FLASH_ATTENTION_SOFTCAPPING:
+                    print(
+                        "Unsloth: If you want to finetune Gemma 2, upgrade flash-attn to version 2.6.3 or higher!\n"\
+                        "Newer versions support faster and less memory usage kernels for Gemma 2's attention softcapping!\n"\
+                        "To update flash-attn, do the below:\n"\
+                        '\npip install --no-deps --upgrade "flash-attn>=2.6.3"'
+                    )
             except:
-                from flash_attn.flash_attn_interface import flash_attn_cuda
-            HAS_FLASH_ATTENTION = True
-
-            # Also check for softcapping
-            from flash_attn import __version__ as flash_attn_version
-            HAS_FLASH_ATTENTION_SOFTCAPPING = Version(flash_attn_version) >= Version("2.6.3")
-            if not HAS_FLASH_ATTENTION_SOFTCAPPING:
                 print(
-                    "Unsloth: If you want to finetune Gemma 2, upgrade flash-attn to version 2.6.3 or higher!\n"\
-                    "Newer versions support faster and less memory usage kernels for Gemma 2's attention softcapping!\n"\
-                    "To update flash-attn, do the below:\n"\
-                    '\npip install --no-deps --upgrade "flash-attn>=2.6.3"'
+                    "Unsloth: Your Flash Attention 2 installation seems to be broken?\n"\
+                    "A possible explanation is you have a new CUDA version which isn't\n"\
+                    "yet compatible with FA2? Please file a ticket to Unsloth or FA2.\n"\
+                    "We shall now use Xformers instead, which does not have any performance hits!\n"\
+                    "We found this negligible impact by benchmarking on 1x A100."
                 )
-        except:
-            print(
-                "Unsloth: Your Flash Attention 2 installation seems to be broken?\n"\
-                "A possible explanation is you have a new CUDA version which isn't\n"\
-                "yet compatible with FA2? Please file a ticket to Unsloth or FA2.\n"\
-                "We shall now use Xformers instead, which does not have any performance hits!\n"\
-                "We found this negligible impact by benchmarking on 1x A100."
-            )
-            # Stop Flash Attention from importing!
-            import transformers.utils.import_utils
-            transformers.utils.import_utils.is_flash_attn_2_available = lambda *args, **kwargs: False
-            import transformers.utils
-            transformers.utils.is_flash_attn_2_available = lambda *args, **kwargs: False
+                # Stop Flash Attention from importing!
+                import transformers.utils.import_utils
+                transformers.utils.import_utils.is_flash_attn_2_available = lambda *args, **kwargs: False
+                import transformers.utils
+                transformers.utils.is_flash_attn_2_available = lambda *args, **kwargs: False
+                HAS_FLASH_ATTENTION = False
+            pass
+        else:
             HAS_FLASH_ATTENTION = False
-        pass
     else:
+        # Tri Dao's benchmark shows xformers is faster for now.
         HAS_FLASH_ATTENTION = False
-else:
-    # Tri Dao's benchmark shows xformers is faster for now.
-    HAS_FLASH_ATTENTION = False
-pass
+    pass
+elif DEVICE_TYPE == "xpu":
+    SUPPORTS_BFLOAT16 = True
 
 from transformers.models.llama.modeling_llama import logger
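Note (not part of the diff): all three files touched by this patch import `DEVICE_TYPE` from the unsloth package root; its definition is outside this diff. A hypothetical sketch of how such a constant could be derived, for readers following along without the rest of the tree (the helper name and exact checks are assumptions, not Unsloth's actual code):

```python
import torch

def _get_device_type() -> str:
    # Prefer NVIDIA GPUs when the CUDA runtime is available.
    if torch.cuda.is_available():
        return "cuda"
    # torch.xpu is the Intel GPU backend shipped with recent PyTorch builds.
    if hasattr(torch, "xpu") and torch.xpu.is_available():
        return "xpu"
    raise NotImplementedError("No supported accelerator (CUDA or XPU) was found.")

DEVICE_TYPE = _get_device_type()
```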
diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py
index 9db8abdd4..c83de8f94 100644
--- a/unsloth/models/llama.py
+++ b/unsloth/models/llama.py
@@ -24,6 +24,8 @@ from transformers import __version__ as transformers_version
 from unsloth_zoo.utils import Version, _get_dtype
 from unsloth_zoo.peft_utils import SKIP_QUANTIZATION_MODULES
+from unsloth import DEVICE_TYPE
+
 transformers_version = Version(transformers_version)
 # Transformers moved rotary embeddings out of all attention layers
 IS_ATTENTION_REFACTOR = transformers_version > Version("4.47.1")
@@ -689,7 +691,7 @@ def LlamaModel_fast_forward(
         position_ids = torch.arange(
             past_key_values_length, seq_length + past_key_values_length,
             dtype  = torch.int32,
-            device = "cuda:0",
+            device = f"{DEVICE_TYPE}:0",
         )
         position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
     elif position_ids is not None:
@@ -861,13 +863,13 @@ def LlamaModel_fast_forward(
                     is_causal = True,
                     sliding_window = self.config.sliding_window,
                 )\
-                .to_causal_4d(1, n, n, dtype = inputs_embeds.dtype, device = "cuda",)\
+                .to_causal_4d(1, n, n, dtype = inputs_embeds.dtype, device = DEVICE_TYPE,)\
                 .squeeze(0).squeeze(0)
 
                 self.GA_mask = AttentionMaskConverter(
                     is_causal = True,
                 )\
-                .to_causal_4d(1, n, n, dtype = inputs_embeds.dtype, device = "cuda",)\
+                .to_causal_4d(1, n, n, dtype = inputs_embeds.dtype, device = DEVICE_TYPE,)\
                 .squeeze(0).squeeze(0)
             pass
         pass
@@ -984,11 +986,11 @@ def LlamaModel_fast_forward_inference_custom(
     bsz, q_len, hd = X.shape
     assert(q_len == 1)
     # Get saved buffers to reduce memory movement
-    residual = torch.empty((bsz, q_len, hd), dtype = torch.float32, device = "cuda:0")
-    _XX = torch.empty((2, bsz, q_len, hd), dtype = torch.float32, device = "cuda:0")
+    residual = torch.empty((bsz, q_len, hd), dtype = torch.float32, device = f"{DEVICE_TYPE}:0")
+    _XX = torch.empty((2, bsz, q_len, hd), dtype = torch.float32, device = f"{DEVICE_TYPE}:0")
     XX, XX2 = _XX[0], _XX[1]
-    variance = torch.empty((bsz, q_len, 1), dtype = torch.float32, device = "cuda:0")
-    temp_mlp = torch.empty((2, bsz, 1, mlp_size), dtype = X.dtype, device = "cuda:0")
+    variance = torch.empty((bsz, q_len, 1), dtype = torch.float32, device = f"{DEVICE_TYPE}:0")
+    temp_mlp = torch.empty((2, bsz, 1, mlp_size), dtype = X.dtype, device = f"{DEVICE_TYPE}:0")
     temp_gate, temp_up = temp_mlp[0], temp_mlp[1]
     seq_len = past_key_values[0][0].shape[-2]
@@ -1305,7 +1307,7 @@ def __init__(self, dim = None, max_position_embeddings=2048, base=10000, device=
             partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0
             dim = getattr(config, "head_dim", None)
             if dim is None: dim = int((config.hidden_size // config.num_attention_heads))
-            device = "cuda"
+            device = DEVICE_TYPE
             max_position_embeddings = config.max_position_embeddings
         pass
@@ -1354,7 +1356,7 @@ def extend_rope_embedding(self, x, seq_len):
         if seq_len <= self.current_rope_size: return
         # Iteratively grow by increments of 8192
         self.current_rope_size = ((seq_len // 8192) + ((seq_len % 8192) != 0)) * 8192
-        self._set_cos_sin_cache(self.current_rope_size, device = "cuda", dtype = x.dtype)
+        self._set_cos_sin_cache(self.current_rope_size, device = DEVICE_TYPE, dtype = x.dtype)
     pass
 pass
@@ -1400,7 +1402,7 @@ def __init__(self, dim = None, max_position_embeddings=2048, base=10000, device=
             base = config.rope_theta
             partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0
             dim = int((config.hidden_size // config.num_attention_heads))
-            device = "cuda"
+            device = DEVICE_TYPE
             max_position_embeddings = config.max_position_embeddings
         pass
@@ -1480,7 +1482,7 @@ def extend_rope_embedding(self, x, seq_len):
         if seq_len <= self.current_rope_size: return
         # Iteratively grow by increments of 8192
         self.current_rope_size = ((seq_len // 8192) + ((seq_len % 8192) != 0)) * 8192
-        self._set_cos_sin_cache(self.current_rope_size, device = "cuda", dtype = x.dtype)
+        self._set_cos_sin_cache(self.current_rope_size, device = DEVICE_TYPE, dtype = x.dtype)
     pass
 pass
@@ -1507,7 +1509,7 @@ def __init__(self,
             base = config.rope_theta
             partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0
             dim = int((config.hidden_size // config.num_attention_heads))
-            device = "cuda"
+            device = DEVICE_TYPE
             max_position_embeddings = config.max_position_embeddings
         pass
@@ -1595,7 +1597,7 @@ def extend_rope_embedding(self, x, seq_len):
         if seq_len <= self.current_rope_size: return
         # Iteratively grow by increments of 8192
         self.current_rope_size = ((seq_len // 8192) + ((seq_len % 8192) != 0)) * 8192
-        self._set_cos_sin_cache(self.current_rope_size, device = "cuda", dtype = x.dtype)
+        self._set_cos_sin_cache(self.current_rope_size, device = DEVICE_TYPE, dtype = x.dtype)
     pass
 pass
@@ -1643,7 +1645,7 @@ def unsloth_fast_generate(
     kwargs["pad_token_id"] = kwargs.pop("pad_token_id", model_eos_token_id)
 
     # Mixed precision autocast
-    with torch.inference_mode(), torch.autocast(device_type = "cuda", dtype = dtype):
+    with torch.inference_mode(), torch.autocast(device_type = DEVICE_TYPE, dtype = dtype):
         output = self._old_generate(*args, **kwargs)
     pass
@@ -1744,19 +1746,36 @@ def from_pretrained(
     if token is None: token = get_token()
     if model_patcher is None: model_patcher = FastLlamaModel
     SUPPORTS_BFLOAT16 = is_bfloat16_supported()
-    gpu_stats = torch.cuda.get_device_properties(0)
-    max_memory = round(gpu_stats.total_memory / 1024 / 1024 / 1024, 3)
-    from importlib.metadata import version as importlib_version
-    try:    vllm_version = f" vLLM: {importlib_version('vllm')}."
-    except: vllm_version = ""
+    if DEVICE_TYPE == "cuda":
+        gpu_stats = torch.cuda.get_device_properties(0)
+        gpu_version = torch.version.cuda
+        gpu_stats_snippet = f"CUDA: {gpu_stats.major}.{gpu_stats.minor}. CUDA Toolkit: {gpu_version}."
+        num_gpus = torch.cuda.device_count()
+
+        from importlib.metadata import version as importlib_version
+        try:    vllm_version = f" vLLM: {importlib_version('vllm')}."
+        except: vllm_version = ""
+    elif DEVICE_TYPE == "xpu":
+        gpu_stats = torch.xpu.get_device_properties(0)
+        gpu_version = torch.version.xpu
+        num_gpus = torch.xpu.device_count()
+        gpu_stats_snippet = f"Intel Toolkit: {gpu_version}."
+
+        # TODO: After adding vLLM support for XPU, change this
+        vllm_version = ""
+    else:
+        raise ValueError(f"Unsloth: Unsupported device type: {DEVICE_TYPE}")
+
+    max_memory = round(gpu_stats.total_memory / 1024 / 1024 / 1024, 3)
 
     statistics = \
-       f"==((====))==  Unsloth {__version__}: Fast {model_patcher.__name__[4:-5]} patching. Transformers: {transformers_version}.{vllm_version}\n"\
-       f"   {chr(92)}{chr(92)}   /|    {gpu_stats.name}. Num GPUs = {torch.cuda.device_count()}. Max memory: {max_memory} GB. Platform: {platform_system}.\n"\
-       f"O^O/ {chr(92)}_/ {chr(92)}    Torch: {torch.__version__}. CUDA: {gpu_stats.major}.{gpu_stats.minor}. CUDA Toolkit: {torch.version.cuda}. Triton: {triton_version}\n"\
-       f"{chr(92)}        /    Bfloat16 = {str(SUPPORTS_BFLOAT16).upper()}. FA [Xformers = {xformers_version}. FA2 = {HAS_FLASH_ATTENTION}]\n"\
-       f' "-____-"     Free license: http://github.com/unslothai/unsloth'
+        f"==((====))==  Unsloth {__version__}: Fast {model_patcher.__name__[4:-5]} patching. Transformers: {transformers_version}.{vllm_version}\n"\
+        f"   {chr(92)}{chr(92)}   /|    {gpu_stats.name}. Num GPUs = {num_gpus}. Max memory: {max_memory} GB. Platform: {platform_system}.\n"\
+        f"O^O/ {chr(92)}_/ {chr(92)}    Torch: {torch.__version__}. {gpu_stats_snippet} Triton: {triton_version}\n"\
+        f"{chr(92)}        /    Bfloat16 = {str(SUPPORTS_BFLOAT16).upper()}. FA [Xformers = {xformers_version}. FA2 = {HAS_FLASH_ATTENTION}]\n"\
+        f' "-____-"     Free license: http://github.com/unslothai/unsloth'
+
     print(statistics)
 
     # Warn about fast transfers
@@ -2215,7 +2234,7 @@ def get_peft_model(
         pass
 
         model.get_input_embeddings().modules_to_save.default\
-            .to(device = "cuda", dtype = new_dtype, non_blocking = True)
+            .to(device = DEVICE_TYPE, dtype = new_dtype, non_blocking = True)
         model.get_input_embeddings().modules_to_save.default.requires_grad_(True)
 
         # [TODO] Move old embed_tokens to CPU - should be disk!
@@ -2235,7 +2254,7 @@ def get_peft_model(
         pass
 
         model.get_output_embeddings().modules_to_save.default\
-            .to(device = "cuda", dtype = new_dtype, non_blocking = True)
+            .to(device = DEVICE_TYPE, dtype = new_dtype, non_blocking = True)
         model.get_output_embeddings().modules_to_save.default.requires_grad_(True)
 
         # [TODO] Move old lm_head to CPU - should be disk!
@@ -2499,7 +2518,7 @@ def get_peft_model(
         pass
 
         model.get_input_embeddings().modules_to_save.default\
-            .to(device = "cuda", dtype = new_dtype, non_blocking = True)
+            .to(device = DEVICE_TYPE, dtype = new_dtype, non_blocking = True)
         model.get_input_embeddings().modules_to_save.default.requires_grad_(True)
     pass
@@ -2515,7 +2534,7 @@ def get_peft_model(
         pass
 
        model.get_output_embeddings().modules_to_save.default\
-            .to(device = "cuda", dtype = new_dtype, non_blocking = True)
+            .to(device = DEVICE_TYPE, dtype = new_dtype, non_blocking = True)
         model.get_output_embeddings().modules_to_save.default.requires_grad_(True)
     pass
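Note (not part of the diff): every hard-coded `"cuda"` / `"cuda:0"` string in llama.py above becomes `DEVICE_TYPE` or `f"{DEVICE_TYPE}:0"`. A standalone sketch of that pattern, assuming a CUDA or XPU device is actually present; the tensor shape and dtype are made up for the demo:

```python
import torch

# Stand-in for the constant imported from unsloth; assumes one of the two backends exists.
DEVICE_TYPE = "cuda" if torch.cuda.is_available() else "xpu"
device = f"{DEVICE_TYPE}:0"  # e.g. "cuda:0" or "xpu:0"

# Pre-allocated fp32 buffer, mirroring LlamaModel_fast_forward_inference_custom above.
bsz, q_len, hd = 1, 1, 4096
residual = torch.empty((bsz, q_len, hd), dtype = torch.float32, device = device)

# Mixed-precision region keyed off the same constant, as in unsloth_fast_generate.
with torch.inference_mode(), torch.autocast(device_type = DEVICE_TYPE, dtype = torch.bfloat16):
    out = residual.to(torch.bfloat16) * 2.0
```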
diff --git a/unsloth/models/vision.py b/unsloth/models/vision.py
index 4466128a2..a140ccdfe 100644
--- a/unsloth/models/vision.py
+++ b/unsloth/models/vision.py
@@ -61,6 +61,7 @@
     # Old HF Hub versions <= 0.0.25
     from huggingface_hub.utils._token import get_token
 pass
+from unsloth import DEVICE_TYPE
 
 __all__ = [
     "FastBaseModel",
@@ -275,12 +276,28 @@ def from_pretrained(
         pass
         if token is None: token = get_token()
         SUPPORTS_BFLOAT16 = is_bfloat16_supported()
-        gpu_stats = torch.cuda.get_device_properties(0)
-        max_memory = round(gpu_stats.total_memory / 1024 / 1024 / 1024, 3)
-        from importlib.metadata import version as importlib_version
-        try:    vllm_version = f" vLLM: {importlib_version('vllm')}."
-        except: vllm_version = ""
+        if DEVICE_TYPE == "cuda":
+            gpu_stats = torch.cuda.get_device_properties(0)
+            gpu_version = torch.version.cuda
+            gpu_stats_snippet = f"CUDA: {gpu_stats.major}.{gpu_stats.minor}. CUDA Toolkit: {gpu_version}."
+            num_gpus = torch.cuda.device_count()
+
+            from importlib.metadata import version as importlib_version
+            try:    vllm_version = f" vLLM: {importlib_version('vllm')}."
+            except: vllm_version = ""
+        elif DEVICE_TYPE == "xpu":
+            gpu_stats = torch.xpu.get_device_properties(0)
+            gpu_version = torch.version.xpu
+            num_gpus = torch.xpu.device_count()
+            gpu_stats_snippet = f"Intel Toolkit: {gpu_version}."
+
+            # TODO: After adding vLLM support for XPU, change this
+            vllm_version = ""
+        else:
+            raise ValueError(f"Unsloth: Unsupported device type: {DEVICE_TYPE}")
+
+        max_memory = round(gpu_stats.total_memory / 1024 / 1024 / 1024, 3)
 
         model_type_arch = model_types[0]
         if model_type_arch == "siglip":
@@ -288,11 +305,12 @@ def from_pretrained(
                 if model_type_arch != "siglip": break
 
         statistics = \
-           f"==((====))==  Unsloth {__version__}: Fast {model_type_arch.title()} patching. Transformers: {transformers_version}.{vllm_version}\n"\
-           f"   {chr(92)}{chr(92)}   /|    {gpu_stats.name}. Num GPUs = {torch.cuda.device_count()}. Max memory: {max_memory} GB. Platform: {platform_system}.\n"\
-           f"O^O/ {chr(92)}_/ {chr(92)}    Torch: {torch.__version__}. CUDA: {gpu_stats.major}.{gpu_stats.minor}. CUDA Toolkit: {torch.version.cuda}. Triton: {triton_version}\n"\
-           f"{chr(92)}        /    Bfloat16 = {str(SUPPORTS_BFLOAT16).upper()}. FA [Xformers = {xformers_version}. FA2 = {HAS_FLASH_ATTENTION}]\n"\
-           f' "-____-"     Free license: http://github.com/unslothai/unsloth'
+            f"==((====))==  Unsloth {__version__}: Fast {model_type_arch.title()} patching. Transformers: {transformers_version}.{vllm_version}\n"\
+            f"   {chr(92)}{chr(92)}   /|    {gpu_stats.name}. Num GPUs = {num_gpus}. Max memory: {max_memory} GB. Platform: {platform_system}.\n"\
+            f"O^O/ {chr(92)}_/ {chr(92)}    Torch: {torch.__version__}. {gpu_stats_snippet} Triton: {triton_version}\n"\
+            f"{chr(92)}        /    Bfloat16 = {str(SUPPORTS_BFLOAT16).upper()}. FA [Xformers = {xformers_version}. FA2 = {HAS_FLASH_ATTENTION}]\n"\
+            f' "-____-"     Free license: http://github.com/unslothai/unsloth'
+
         print(statistics)
 
         # Warn about fast transfers
@@ -500,7 +518,10 @@ def from_pretrained(
         # Clear deleted GPU items
         for _ in range(3):
             gc.collect()
-            torch.cuda.empty_cache()
+            if DEVICE_TYPE == "cuda":
+                torch.cuda.empty_cache()
+            elif DEVICE_TYPE == "xpu":
+                torch.xpu.empty_cache()
         pass
         return model, tokenizer
     pass
@@ -566,7 +587,10 @@ def get_peft_model(
         # Clear deleted GPU items
         for _ in range(3):
             gc.collect()
-            torch.cuda.empty_cache()
+            if DEVICE_TYPE == "cuda":
+                torch.cuda.empty_cache()
+            elif DEVICE_TYPE == "xpu":
+                torch.xpu.empty_cache()
         pass
         max_seq_length = model.max_seq_length
         lora_config = LoraConfig(
@@ -591,7 +615,10 @@ def get_peft_model(
         # Clear deleted GPU items
         for _ in range(3):
             gc.collect()
-            torch.cuda.empty_cache()
+            if DEVICE_TYPE == "cuda":
+                torch.cuda.empty_cache()
+            elif DEVICE_TYPE == "xpu":
+                torch.xpu.empty_cache()
         pass
 
         patch_saving_functions(model, vision = True)
@@ -653,7 +680,10 @@ def post_patch_model(
         # Clear deleted GPU items
         for _ in range(3):
             gc.collect()
-            torch.cuda.empty_cache()
+            if DEVICE_TYPE == "cuda":
+                torch.cuda.empty_cache()
+            elif DEVICE_TYPE == "xpu":
+                torch.xpu.empty_cache()
         pass
         # Add for_inference and for_training
         model.for_training = functools.partial(FastBaseModel.for_training, model)
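Note (not part of the diff): the same three-line cache-clearing branch now appears four times in vision.py. One possible follow-up, sketched here as a hypothetical helper rather than anything in this PR, would be to centralize it:

```python
import gc
import torch

def clear_accelerator_cache(device_type: str, rounds: int = 3) -> None:
    # Mirrors the repeated "for _ in range(3): gc.collect(); empty_cache()" pattern above.
    for _ in range(rounds):
        gc.collect()
        if device_type == "cuda":
            torch.cuda.empty_cache()
        elif device_type == "xpu":
            torch.xpu.empty_cache()

# Hypothetical usage at each call site:
# clear_accelerator_cache(DEVICE_TYPE)
```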