2 changes: 1 addition & 1 deletion unsloth/kernels/utils.py
@@ -78,7 +78,7 @@ def calculate_settings(n : int) -> (int, int,):
 # INTEL GPU specific logic
 if DEVICE_TYPE == "xpu":
     # TODO: Changed here after adding XPU BNB support
-    HAS_XPU_STREAM = False
+    HAS_XPU_STREAM = True
     def get_ptr(x: Optional[torch.Tensor]):
         raise RuntimeError("XPU BNB support is not implemented yet. This function should not be called.")
 else:
10 changes: 8 additions & 2 deletions unsloth/models/_utils.py
@@ -142,6 +142,12 @@
 import logging
 logging.getLogger("transformers.tokenization_utils_base").setLevel(logging.CRITICAL+1)
 
+def get_device_num():
+    if DEVICE_TYPE == "xpu":
+        return torch.xpu.device_count()
+    else:
+        return torch.cuda.device_count()
+
 # Ignore logging messages
 class HideLoggingMessage(logging.Filter):
     __slots__ = "text",
@@ -738,7 +744,7 @@ def get_statistics():
         pass
     pass
     try:
-        devices = torch.cuda.device_count()
+        devices = get_device_num()
         _get_statistics(f"{devices if devices <= 8 else 9}")
     except:
         pass
@@ -765,7 +771,7 @@
 )
 exec(BitsAndBytesConfig__init__, globals())
 
-if torch.cuda.device_count() == 1:
+if get_device_num() == 1:
     from accelerate.utils.dataclasses import DistributedType
     def _prepare_backend(self, *args, **kwargs): return None, DistributedType.NO
     import accelerate.state
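For context on the `_utils.py` change: `get_device_num()` is a small dispatch on `DEVICE_TYPE` so device counting works on both Intel XPU and NVIDIA CUDA. A minimal standalone sketch of the pattern is below; the `DEVICE_TYPE` detection shown here is an assumption for illustration only (Unsloth derives it during package import).

```python
import torch

# Assumed detection for this sketch -- Unsloth computes DEVICE_TYPE itself.
if hasattr(torch, "xpu") and torch.xpu.is_available():
    DEVICE_TYPE = "xpu"
elif torch.cuda.is_available():
    DEVICE_TYPE = "cuda"
else:
    DEVICE_TYPE = "cpu"

def get_device_num():
    # Count accelerators on the active backend (Intel XPU vs. NVIDIA CUDA).
    if DEVICE_TYPE == "xpu":
        return torch.xpu.device_count()
    return torch.cuda.device_count()

print(f"{DEVICE_TYPE}: {get_device_num()} device(s)")
```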
31 changes: 20 additions & 11 deletions unsloth/models/llama.py
@@ -85,6 +85,11 @@
 HAS_XFORMERS = xformers is not None
 BlockDiagonalCausalMask = xformers.attn_bias.BlockDiagonalCausalMask if HAS_XFORMERS else None
 
+def clean_gpu_cache():
+    if DEVICE_TYPE == "xpu":
+        torch.xpu.empty_cache()
+    else:
+        torch.cuda.empty_cache()
 
 def original_apply_qkv(self, X):
     Q = self.q_proj(X)
@@ -1752,10 +1757,11 @@ def from_pretrained(
         if not is_vLLM_available():
             print("Unsloth: vLLM is not installed! Will use Unsloth inference!")
             fast_inference = False
-        major_version, minor_version = torch.cuda.get_device_capability()
-        if major_version < 7:
-            print("Unsloth: vLLM does not work on older GPUs - will switch to Unsloth inference!")
-            fast_inference = False
+        if DEVICE_TYPE == "cuda":
+            major_version, minor_version = torch.cuda.get_device_capability()
+            if major_version < 7:
+                print("Unsloth: vLLM does not work on older GPUs - will switch to Unsloth inference!")
+                fast_inference = False
         if unsloth_vllm_standby and os.environ.get("UNSLOTH_VLLM_STANDBY", "0") == "0":
             raise RuntimeError("Unsloth: `unsloth_vllm_standby` is True, but environment variable `UNSLOTH_VLLM_STANDBY` is not set to 1!")
         pass
@@ -1779,8 +1785,8 @@ def from_pretrained(
             num_gpus = torch.xpu.device_count()
             gpu_stats_snippet = f"Intel Toolkit: {gpu_version}."
 
-            # TODO: After adding vLLM support for XPU, changed this
-            vllm_version = ""
+            try: vllm_version = f" vLLM: {importlib_version('vllm')}."
+            except: vllm_version = ""
         else:
             raise ValueError(f"Unsloth: Unsupported device type: {DEVICE_TYPE}")
 
@@ -2020,7 +2026,10 @@ def from_pretrained(
     import gc
     for _ in range(3):
         gc.collect()
-        torch.cuda.empty_cache()"""
+        if DEVICE_TYPE == "xpu":
+            torch.xpu.empty_cache()
+        else:
+            torch.cuda.empty_cache()"""
 
     debug_info = debug_info.split('\n')
     debug_info = "\n".join([debug_info[0]] + [spaces + x[8:] for x in debug_info[1:]])
Expand Down Expand Up @@ -2508,7 +2517,7 @@ def get_peft_model(
# Remove old items to save VRAM
for _ in range(3):
gc.collect()
torch.cuda.empty_cache()
clean_gpu_cache()
pass

if train_lm_head:
@@ -2519,7 +2528,7 @@ def get_peft_model(
         # Remove old items to save VRAM
         for _ in range(3):
             gc.collect()
-            torch.cuda.empty_cache()
+            clean_gpu_cache()
         pass
     pass
 
@@ -2580,7 +2589,7 @@ def get_peft_model(
     # Clear deleted GPU items
     for _ in range(3):
         gc.collect()
-        torch.cuda.empty_cache()
+        clean_gpu_cache()
     pass
 
     # Patch for fast inference
@@ -2796,7 +2805,7 @@ def patch_peft_model(
     # Clear deleted GPU items
     for _ in range(3):
         gc.collect()
-        torch.cuda.empty_cache()
+        clean_gpu_cache()
     pass
 
     # Patch for fast inference
4 changes: 3 additions & 1 deletion unsloth/models/rl_replacements.py
@@ -26,6 +26,8 @@
 import inspect
 from collections import defaultdict
 from unsloth_zoo.rl_replacements import RL_REPLACEMENTS
+from unsloth import DEVICE_TYPE
+
 RL_EXTRA_ARGS = defaultdict(list)
 RL_FUNCTIONS = defaultdict(list)
 RL_PRE_ITEMS = defaultdict(list)
@@ -258,7 +260,7 @@ def _get_per_token_logps(self, model, input_ids, attention_mask, logits_to_keep)
     if os.environ.get('UNSLOTH_FORCE_FLOAT32', '0') == '1': self._autocast_dtype = torch.float16
 
     os.environ["UNSLOTH_RETURN_HIDDEN_STATES"] = "1"
-    with torch.amp.autocast(device_type = 'cuda', dtype = self._autocast_dtype):
+    with torch.amp.autocast(device_type = DEVICE_TYPE, dtype = self._autocast_dtype):
         # We add 1 to `logits_to_keep` because the last logits of the sequence is later excluded
         logits = model(
             input_ids = input_ids,
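Taken together, the `llama.py` and `rl_replacements.py` changes route cache clearing and autocast through the detected backend instead of hard-coding CUDA. A hedged standalone sketch of that combined pattern follows; the device detection and the `free_vram` helper name are assumptions for illustration, not part of this PR.

```python
import gc
import torch

# Assumed detection for this sketch -- Unsloth exposes DEVICE_TYPE from the package.
if hasattr(torch, "xpu") and torch.xpu.is_available():
    DEVICE_TYPE = "xpu"
elif torch.cuda.is_available():
    DEVICE_TYPE = "cuda"
else:
    DEVICE_TYPE = "cpu"

def clean_gpu_cache():
    # Drop cached allocator blocks on whichever backend is active.
    if DEVICE_TYPE == "xpu":
        torch.xpu.empty_cache()
    else:
        torch.cuda.empty_cache()  # no-op if CUDA was never initialized

def free_vram():
    # Mirrors the repeated "gc.collect(), then empty the cache" loop in the diff.
    for _ in range(3):
        gc.collect()
        clean_gpu_cache()

# Device-agnostic autocast, as in the rl_replacements.py change.
dtype = torch.bfloat16 if DEVICE_TYPE == "cpu" else torch.float16
with torch.amp.autocast(device_type = DEVICE_TYPE, dtype = dtype):
    x = torch.ones(64, 64, device = DEVICE_TYPE) @ torch.ones(64, 64, device = DEVICE_TYPE)

free_vram()
```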