Merged
Changes from 2 commits
15 changes: 13 additions & 2 deletions unsloth/__init__.py
@@ -22,7 +22,7 @@
already_imported = [mod for mod in critical_modules if mod in sys.modules]

# This check is critical because Unsloth optimizes these libraries by modifying
# their code at import time. If they're imported first, the original (slower,
# their code at import time. If they're imported first, the original (slower,
# more memory-intensive) implementations will be used instead of Unsloth's
# optimized versions, potentially causing OOM errors or slower training.

@@ -73,6 +73,17 @@ def get_device_type():
pass
DEVICE_TYPE : str = get_device_type()

def get_device_count():
if DEVICE_TYPE == "cuda":
return torch.cuda.device_count()
elif DEVICE_TYPE == "xpu":
return torch.xpu.device_count()
else:
return 0
Contributor suggested change:
-        return 0
+        return 1

pass

DEVICE_COUNT : int = get_device_count()
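Editorial note on the review suggestion above: downstream modules size their per-device caches as `[None] * DEVICE_COUNT`, so returning 0 on a CPU-only machine would leave those lists empty and make the first index lookup fail. A minimal sketch, assuming a fallback to 1 (the helper name and fallback are illustrative, not part of this PR):

```python
# Hypothetical illustration, not part of this PR: why a device count of 0 is risky.
def get_device_count_with_fallback():
    """Return the accelerator count, falling back to 1 so per-device
    caches such as `[None] * DEVICE_COUNT` are never empty."""
    try:
        import torch
        if torch.cuda.is_available():
            return torch.cuda.device_count()
        if hasattr(torch, "xpu") and torch.xpu.is_available():
            return torch.xpu.device_count()
    except ImportError:
        pass
    return 1  # CPU-only: treat the host as one "device" so index 0 stays valid

caches = [None] * get_device_count_with_fallback()
assert len(caches) >= 1  # index 0 can always be populated
```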

# Reduce VRAM usage by reducing fragmentation
# And optimize pinning of memory
if DEVICE_TYPE == "cuda" and os.environ.get("UNSLOTH_VLLM_STANDBY", "0")=="0":
@@ -237,4 +248,4 @@ def is_bf16_supported(): return SUPPORTS_BFLOAT16
from .trainer import *

# Patch TRL trainers for backwards compatibility
_patch_trl_trainer()
_patch_trl_trainer()
22 changes: 11 additions & 11 deletions unsloth/kernels/utils.py
@@ -100,7 +100,7 @@ def torch_gpu_device(device): return nullcontext()

# INTEL GPU Specific Logic
if DEVICE_TYPE == "xpu":
_gpu_getCurrentRawStream = torch._C._xpu_getCurrentRawStream
_gpu_getCurrentRawStream = torch._C._xpu_getCurrentRawStream
# NVIDIA GPU Default Logic
else:
_gpu_getCurrentRawStream = torch._C._cuda_getCurrentRawStream
@@ -126,7 +126,7 @@ def _get_tensor_stream(tensor: torch_Tensor) -> c_void_p:
XPU_STREAMS = [None] * (max(_XPU_STREAMS.keys()) + 1)
WEIGHT_BUFFERS = [None] * (max(_XPU_STREAMS.keys()) + 1)
ABSMAX_BUFFERS = [None] * (max(_XPU_STREAMS.keys()) + 1)
for k, v in _XPU_STREAMS.items():
for k, v in _XPU_STREAMS.items():
XPU_STREAMS[k] = v
XPU_STREAMS = tuple(XPU_STREAMS)
del _XPU_STREAMS
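For readers skimming this hunk, a small standalone sketch (with made-up placeholder values; the real code stores raw GPU stream handles) of the pattern above: a `{device_index: stream}` mapping is expanded into an index-addressable tuple so hot paths do a positional lookup instead of dict hashing.

```python
# Illustrative only; values are placeholders for per-device stream handles.
_streams_by_index = {0: "stream-gpu0", 2: "stream-gpu2"}  # hypothetical handles

streams = [None] * (max(_streams_by_index.keys()) + 1)
for idx, stream in _streams_by_index.items():
    streams[idx] = stream
streams = tuple(streams)  # immutable, O(1) positional lookup

assert streams[2] == "stream-gpu2" and streams[1] is None
```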
@@ -152,16 +152,16 @@ def _get_tensor_stream(tensor: torch_Tensor) -> c_void_p:
# TODO: After adding XPU BNB support, this function should be implemented
def cdequantize_blockwise_fp32(*args, **kwargs):
raise RuntimeError("XPU BNB support is not implemented yet. cdequantize_blockwise_fp32 should not be called now.")

def cdequantize_blockwise_fp16_nf4(*args, **kwargs):
raise RuntimeError("XPU BNB support is not implemented yet. cdequantize_blockwise_fp16_nf4 should not be called now.")

def cdequantize_blockwise_bf16_nf4(*args, **kwargs):
raise RuntimeError("XPU BNB support is not implemented yet. cdequantize_blockwise_bf16_nf4 should not be called now.")

def cgemm_4bit_inference_naive_fp16(*args, **kwargs):
raise RuntimeError("XPU BNB support is not implemented yet. cgemm_4bit_inference_naive_fp16 should not be called now.")

def cgemm_4bit_inference_naive_bf16(*args, **kwargs):
raise RuntimeError("XPU BNB support is not implemented yet. cgemm_4bit_inference_naive_bf16 should not be called now.")
else:
@@ -193,7 +193,7 @@ def get_lora_parameters(proj):
adapter = getattr(proj, "active_adapters", None)
if adapter is None: adapter = getattr(proj, "active_adapter", ("default"))
adapter = adapter[0]

return (
W,
getattr(W, "quant_state", None),
@@ -232,7 +232,7 @@ def get_lora_parameters_bias(proj):
if DEVICE_TYPE == "xpu" and HAS_XPU_STREAM:
@torch.inference_mode
def fast_dequantize(W, quant_state = None, out = None, use_global_buffer = False):
# TODO: After adding XPU BNB support, check this function
# TODO: After adding XPU BNB support, check this function
if quant_state is None: return W
if type(quant_state) is not list:
# New quant_state as a class
@@ -535,7 +535,7 @@ def fast_gemv(X, W, quant_state, out = None):
device = W.device
device_index = device.index
CUDA_STREAM = CUDA_STREAMS[device_index]

# assert(dtype == X.dtype)
bout = shape[0]

@@ -669,7 +669,7 @@ def fast_linear_forward(proj, X, temp_lora = None, out = None):
lora_A._fast_lora = lora_A.to(dtype)
lora_B._fast_lora = lora_B.to(dtype)
pass

if bsz == 1:
out = out.view(out_dim)
temp_lora = torch_mv(lora_A._fast_lora, X.ravel(), out = temp_lora)
@@ -709,6 +709,6 @@ def matmul_lora(X, W, W_quant, A, B, s, out = None):
out.addmm_(XA, B.to(dtype), alpha = s)
# out += (X @ A.to(dtype)) @ (s * B.to(dtype))
pass

return out.view(batch, seq_len, -1) if reshape else out
pass
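As background for the matmul_lora hunk above, here is a plain PyTorch reference of the computation the fused addmm_ call performs, matching the inline comment `out += (X @ A.to(dtype)) @ (s * B.to(dtype))`. The shapes and layout below are assumptions for illustration, not the library's exact storage format.

```python
import torch

def matmul_lora_reference(X, W, A, B, s):
    """Unfused reference sketch: base projection plus scaled low-rank update."""
    out = X @ W                 # base weight path
    out += (X @ A) @ (s * B)    # LoRA path, scaled by s
    return out

# Tiny usage example with hypothetical shapes
X = torch.randn(2, 8)    # (batch, in_features)
W = torch.randn(8, 16)   # (in_features, out_features)
A = torch.randn(8, 4)    # (in_features, rank)
B = torch.randn(4, 16)   # (rank, out_features)
print(matmul_lora_reference(X, W, A, B, s=0.5).shape)  # torch.Size([2, 16])
```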
14 changes: 8 additions & 6 deletions unsloth/models/gemma.py
@@ -228,11 +228,11 @@ def __init__(self, dim = None, max_position_embeddings=2048, base=10000, device=
self.base = base
# Dynamic RoPE we first set it to a max of 4 * 8192 tokens then we iteratively grow this
self.current_rope_size = min(4 * 8192, self.max_position_embeddings)
self.multi_gpu_cos_cached = [None]*torch.cuda.device_count()
self.multi_gpu_sin_cached = [None]*torch.cuda.device_count()
self.multi_gpu_cos_cached = [None]*DEVICE_COUNT
self.multi_gpu_sin_cached = [None]*DEVICE_COUNT

# Build here to make `torch.jit.trace` work.
for device in range(torch.cuda.device_count()):
for device in range(DEVICE_COUNT):
self._set_cos_sin_cache(seq_len=self.current_rope_size, device=torch.device(device), dtype=torch.get_default_dtype())

# dummy so that patch_utils doesn't fail for now
@@ -268,9 +268,11 @@ def forward(self, x, position_ids=None, seq_len=None):
if seq_len is not None and seq_len > self.current_rope_size:
self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)

device_index = x.device.index

return (
self.multi_gpu_cos_cached[x.device.index][:seq_len].to(dtype=x.dtype),
self.multi_gpu_sin_cached[x.device.index][:seq_len].to(dtype=x.dtype),
self.multi_gpu_cos_cached[device_index][:seq_len].to(dtype=x.dtype),
self.multi_gpu_sin_cached[device_index][:seq_len].to(dtype=x.dtype),
)
pass
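To make the device_index change easier to follow, here is a self-contained sketch of the multi-GPU RoPE cache pattern used in gemma.py and llama.py: cos/sin tables are built once per visible device, then looked up by `x.device.index` in forward. Class name, shapes, and the CPU fallback are assumptions for illustration.

```python
import torch

class PerDeviceRoPECache:
    """Illustrative sketch of the per-device cos/sin cache pattern."""
    def __init__(self, dim=64, max_seq_len=2048, base=10000.0, device_count=1):
        self.cos_cached = [None] * device_count
        self.sin_cached = [None] * device_count
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        t = torch.arange(max_seq_len, dtype=torch.float32)
        freqs = torch.outer(t, inv_freq)
        emb = torch.cat((freqs, freqs), dim=-1)
        for idx in range(device_count):
            dev = torch.device("cuda", idx) if torch.cuda.is_available() else torch.device("cpu")
            self.cos_cached[idx] = emb.cos().to(dev)
            self.sin_cached[idx] = emb.sin().to(dev)

    def forward(self, x, seq_len):
        idx = x.device.index if x.device.index is not None else 0  # CPU devices have index None
        return (
            self.cos_cached[idx][:seq_len].to(dtype=x.dtype),
            self.sin_cached[idx][:seq_len].to(dtype=x.dtype),
        )

# Usage: a (bs, heads, seq_len, head_dim) activation picks its own device's cache.
x = torch.randn(1, 8, 16, 64)
cos, sin = PerDeviceRoPECache().forward(x, seq_len=16)
print(cos.shape, sin.shape)  # torch.Size([16, 64]) torch.Size([16, 64])
```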

@@ -284,7 +286,7 @@ def extend_rope_embedding(self, x, seq_len):
if seq_len <= self.current_rope_size: return
# Iteratively grow by increments of 8192
self.current_rope_size = math.ceil(seq_len / 8192) * 8192
for device in range(torch.cuda.device_count()):
for device in range(DEVICE_COUNT):
self._set_cos_sin_cache(self.current_rope_size, device = torch.device(device), dtype = x.dtype)
pass
pass
61 changes: 31 additions & 30 deletions unsloth/models/llama.py
@@ -25,7 +25,7 @@
from transformers import __version__ as transformers_version
from unsloth_zoo.utils import Version, _get_dtype
from unsloth_zoo.peft_utils import SKIP_QUANTIZATION_MODULES
from unsloth import DEVICE_TYPE
from unsloth import DEVICE_TYPE, DEVICE_COUNT

transformers_version = Version(transformers_version)
# Transformers moved rotary embeddings out of all attention layers
@@ -1357,11 +1357,11 @@ def __init__(self, dim = None, max_position_embeddings=2048, base=10000, device=
self.base = base
# Dynamic RoPE we first set it to a max of 4 * 8192 tokens then we iteratively grow this
self.current_rope_size = min(4 * 8192, self.max_position_embeddings)
self.multi_gpu_cos_cached = [None]*torch.cuda.device_count()
self.multi_gpu_sin_cached = [None]*torch.cuda.device_count()
self.multi_gpu_cos_cached = [None]*NUM_GPUS
self.multi_gpu_sin_cached = [None]*NUM_GPUS

# Build here to make `torch.jit.trace` work.
for device_idx in range(torch.cuda.device_count()):
for device_idx in range(NUM_GPUS):
self._set_cos_sin_cache(seq_len=self.current_rope_size, device=torch.device(device_idx), dtype=torch.get_default_dtype())

# dummy so that patch_utils doesn't fail for now
@@ -1393,23 +1393,24 @@ def forward(self, x, position_ids=None, seq_len=None):
if seq_len is not None and seq_len > self.current_rope_size:
self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)

device_index = x.device.index
return (
self.multi_gpu_cos_cached[x.device.index][:seq_len].to(dtype = x.dtype),
self.multi_gpu_sin_cached[x.device.index][:seq_len].to(dtype = x.dtype),
self.multi_gpu_cos_cached[device_index][:seq_len].to(dtype = x.dtype),
self.multi_gpu_sin_cached[device_index][:seq_len].to(dtype = x.dtype),
)
pass

def get_cached(self, seq_len = None, device = None):
if device is None:
device = torch.cuda.current_device()
return self.multi_gpu_cos_cached[device.index if hasattr(device, 'index') else device], self.multi_gpu_sin_cached[device.index if hasattr(device, 'index') else device]
return self.multi_gpu_cos_cached[device.index], self.multi_gpu_sin_cached[device.index]
pass
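One editorial observation on the simplified get_cached above (an aside, not a change in the PR): torch.cuda.current_device() returns a plain int rather than a torch.device, so the bare `device.index` lookup assumes callers always pass a torch.device. A defensive normalization sketch, using a hypothetical helper name:

```python
import torch

def normalize_device_index(device=None):
    """Hypothetical helper: map None, an int, or a torch.device to an int index."""
    if device is None:
        # torch.cuda.current_device() returns an int index, not a torch.device
        device = torch.cuda.current_device() if torch.cuda.is_available() else 0
    if isinstance(device, torch.device):
        return device.index if device.index is not None else 0
    return int(device)

print(normalize_device_index(torch.device("cpu")))  # 0
print(normalize_device_index(1))                    # 1
```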

def extend_rope_embedding(self, x, seq_len):
if seq_len <= self.current_rope_size: return
# Iteratively grow by increments of 8192
self.current_rope_size = ((seq_len // 8192) + ((seq_len % 8192) != 0)) * 8192
for device_idx in range(torch.cuda.device_count()):
for device_idx in range(NUM_GPUS):
self._set_cos_sin_cache(self.current_rope_size, device = torch.device(device_idx), dtype = x.dtype)
pass
pass
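Minor aside: the rounding expression here is an integer ceil-to-multiple, equivalent to the math.ceil form used in gemma.py. A quick illustrative check:

```python
import math

def round_up_to_8192(seq_len: int) -> int:
    # integer form used above: ceil(seq_len / 8192) * 8192 without float division
    return ((seq_len // 8192) + ((seq_len % 8192) != 0)) * 8192

for n in (1, 8192, 8193, 20000):
    assert round_up_to_8192(n) == math.ceil(n / 8192) * 8192
print(round_up_to_8192(20000))  # 24576
```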
@@ -1468,8 +1469,8 @@ def __init__(self, dim = None, max_position_embeddings=2048, base=10000, device=
self.base = base
# Dynamic RoPE we first set it to a max of 4 * 8192 tokens then we iteratively grow this
self.current_rope_size = min(4 * 8192, self.max_position_embeddings)
self.multi_gpu_cos_cached = [None]*torch.cuda.device_count()
self.multi_gpu_sin_cached = [None]*torch.cuda.device_count()
self.multi_gpu_cos_cached = [None]*NUM_GPUS
self.multi_gpu_sin_cached = [None]*NUM_GPUS

# Normal Llama-3 RoPE
inv_freq = 1.0 / (
@@ -1479,7 +1480,7 @@ def __init__(self, dim = None, max_position_embeddings=2048, base=10000, device=
self.register_buffer("inv_freq", inv_freq, persistent = False)

# Build here to make `torch.jit.trace` work.
for device_idx in range(torch.cuda.device_count()):
for device_idx in range(NUM_GPUS):
self._set_cos_sin_cache(seq_len=self.current_rope_size, device=torch.device(device_idx), dtype=torch.get_default_dtype())

# dummy so that patch_utils doesn't fail for now
@@ -1508,24 +1509,24 @@ def forward(self, x, position_ids=None, seq_len=None):
# x: [bs, num_attention_heads, seq_len, head_size]
if seq_len is not None and seq_len > self.current_rope_size:
self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)

device_index = x.device.index
return (
self.multi_gpu_cos_cached[x.device.index][:seq_len].to(dtype = x.dtype),
self.multi_gpu_sin_cached[x.device.index][:seq_len].to(dtype = x.dtype),
self.multi_gpu_cos_cached[device_index][:seq_len].to(dtype = x.dtype),
self.multi_gpu_sin_cached[device_index][:seq_len].to(dtype = x.dtype),
)
pass

def get_cached(self, seq_len = None, device = None):
if device is None:
device = torch.cuda.current_device()
return self.multi_gpu_cos_cached[device.index if hasattr(device, 'index') else device], self.multi_gpu_sin_cached[device.index if hasattr(device, 'index') else device]
return self.multi_gpu_cos_cached[device.index], self.multi_gpu_sin_cached[device.index]
pass

def extend_rope_embedding(self, x, seq_len):
if seq_len <= self.current_rope_size: return
# Iteratively grow by increments of 8192
self.current_rope_size = ((seq_len // 8192) + ((seq_len % 8192) != 0)) * 8192
for device_idx in range(torch.cuda.device_count()):
for device_idx in range(NUM_GPUS):
self._set_cos_sin_cache(self.current_rope_size, device = torch.device(device_idx), dtype = x.dtype)
pass

@@ -1589,10 +1590,10 @@ def __init__(self,
self.base = base
# Dynamic RoPE we first set it to a max of 4 * 8192 tokens then we iteratively grow this
self.current_rope_size = min(original_max_position_embeddings, self.max_position_embeddings)
self.multi_gpu_short_cos_cached = [None]*torch.cuda.device_count()
self.multi_gpu_short_sin_cached = [None]*torch.cuda.device_count()
self.multi_gpu_long_cos_cached = [None]*torch.cuda.device_count()
self.multi_gpu_long_sin_cached = [None]*torch.cuda.device_count()
self.multi_gpu_short_cos_cached = [None]*NUM_GPUS
self.multi_gpu_short_sin_cached = [None]*NUM_GPUS
self.multi_gpu_long_cos_cached = [None]*NUM_GPUS
self.multi_gpu_long_sin_cached = [None]*NUM_GPUS

# Long RoPE similar to RoPE except short sequences have 1 cos / sin
# and long sequences have another cos / sin
@@ -1622,7 +1623,7 @@ def __init__(self,
freqs = torch.outer(t, self.short_inv_freq)
emb = torch.cat((freqs, freqs), dim=-1)

for device_idx in range(torch.cuda.device_count()):
for device_idx in range(NUM_GPUS):
device_obj = torch.device(device_idx)
cos_cached = (emb.cos() * self.scaling_factor).to(dtype=dtype, device=device_obj, non_blocking=True)
sin_cached = (emb.sin() * self.scaling_factor).to(dtype=dtype, device=device_obj, non_blocking=True)
@@ -1657,23 +1658,25 @@ def forward(self, x, position_ids=None, seq_len=None):
if seq_len is not None and seq_len > self.current_rope_size:
self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)

device_index = x.device.index

if seq_len is not None and seq_len < self.original_max_position_embeddings:
return (
self.multi_gpu_short_cos_cached[x.device.index][:seq_len].to(dtype = x.dtype),
self.multi_gpu_short_sin_cached[x.device.index][:seq_len].to(dtype = x.dtype),
self.multi_gpu_short_cos_cached[device_index][:seq_len].to(dtype = x.dtype),
self.multi_gpu_short_sin_cached[device_index][:seq_len].to(dtype = x.dtype),
)
else:
return (
self.multi_gpu_long_cos_cached[x.device.index][:seq_len].to(dtype = x.dtype),
self.multi_gpu_long_sin_cached[x.device.index][:seq_len].to(dtype = x.dtype),
self.multi_gpu_long_cos_cached[device_index][:seq_len].to(dtype = x.dtype),
self.multi_gpu_long_sin_cached[device_index][:seq_len].to(dtype = x.dtype),
)
pass
pass

def get_cached(self, seq_len = None, device = None):
if device is None:
device = torch.cuda.current_device()
device_index = device.index if hasattr(device, 'index') else device
device_index = device.index
if seq_len is not None and seq_len < self.original_max_position_embeddings:
return self.multi_gpu_short_cos_cached[device_index], self.multi_gpu_short_sin_cached[device_index]
return self.multi_gpu_long_cos_cached[device_index], self.multi_gpu_long_sin_cached[device_index]
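For context on this LongRoPE forward and get_cached, a compact sketch (a simplified assumption, not the library's API) of the two-regime selection: sequences shorter than the original max position count use the short-factor cache, longer ones use the long-factor cache.

```python
def select_rope_cache(seq_len, original_max, short_cache, long_cache):
    """Illustrative only: pick the short or long cos/sin cache by sequence length."""
    if seq_len is not None and seq_len < original_max:
        return short_cache
    return long_cache

# Hypothetical placeholder caches
print(select_rope_cache(1024, 4096, "short cos/sin", "long cos/sin"))   # short cos/sin
print(select_rope_cache(16384, 4096, "short cos/sin", "long cos/sin"))  # long cos/sin
```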
@@ -1683,7 +1686,7 @@ def extend_rope_embedding(self, x, seq_len):
if seq_len <= self.current_rope_size: return
# Iteratively grow by increments of 8192
self.current_rope_size = ((seq_len // 8192) + ((seq_len % 8192) != 0)) * 8192
for device_idx in range(torch.cuda.device_count()):
for device_idx in range(NUM_GPUS):
self._set_cos_sin_cache(self.current_rope_size, device = torch.device(device_idx), dtype = x.dtype)
pass
pass
@@ -1839,15 +1842,13 @@ def from_pretrained(
gpu_stats = torch.cuda.get_device_properties(0)
gpu_version = torch.version.cuda
gpu_stats_snippet = f"CUDA: {gpu_stats.major}.{gpu_stats.minor}. CUDA Toolkit: {gpu_version}."
num_gpus = torch.cuda.device_count()

from importlib.metadata import version as importlib_version
try: vllm_version = f" vLLM: {importlib_version('vllm')}."
except: vllm_version = ""
elif DEVICE_TYPE == "xpu":
gpu_stats = torch.xpu.get_device_properties(0)
gpu_version = torch.version.xpu
num_gpus = torch.xpu.device_count()
gpu_stats_snippet = f"Intel Toolkit: {gpu_version}."

try: vllm_version = f" vLLM: {importlib_version('vllm')}."
@@ -1859,7 +1860,7 @@

statistics = \
f"==((====))== Unsloth {__version__}: Fast {model_patcher.__name__[4:-5]} patching. Transformers: {transformers_version}.{vllm_version}\n"\
f" {chr(92)}{chr(92)} /| {gpu_stats.name}. Num GPUs = {num_gpus}. Max memory: {max_memory} GB. Platform: {platform_system}.\n"\
f" {chr(92)}{chr(92)} /| {gpu_stats.name}. Num GPUs = {DEVICE_COUNT}. Max memory: {max_memory} GB. Platform: {platform_system}.\n"\
f"O^O/ {chr(92)}_/ {chr(92)} Torch: {torch.__version__}. {gpu_stats_snippet} Triton: {triton_version}\n"\
f"{chr(92)} / Bfloat16 = {str(SUPPORTS_BFLOAT16).upper()}. FA [Xformers = {xformers_version}. FA2 = {HAS_FLASH_ATTENTION}]\n"\
f' "-____-" Free license: http://github.com/unslothai/unsloth'