15 changes: 13 additions & 2 deletions unsloth/__init__.py
@@ -22,7 +22,7 @@
already_imported = [mod for mod in critical_modules if mod in sys.modules]

# This check is critical because Unsloth optimizes these libraries by modifying
# their code at import time. If they're imported first, the original (slower,
# more memory-intensive) implementations will be used instead of Unsloth's
# optimized versions, potentially causing OOM errors or slower training.

@@ -73,6 +73,17 @@ def get_device_type():
pass
DEVICE_TYPE : str = get_device_type()

def get_device_count():
if DEVICE_TYPE == "cuda":
return torch.cuda.device_count()
elif DEVICE_TYPE == "xpu":
return torch.xpu.device_count()
else:
return 0
Contributor review comment (suggested change): return 1 instead of return 0.

pass

DEVICE_COUNT : int = get_device_count()

# Reduce VRAM usage by reducing fragmentation
# And optimize pinning of memory
if DEVICE_TYPE == "cuda" and os.environ.get("UNSLOTH_VLLM_STANDBY", "0")=="0":
@@ -237,4 +248,4 @@ def is_bf16_supported(): return SUPPORTS_BFLOAT16
from .trainer import *

# Patch TRL trainers for backwards compatibility
_patch_trl_trainer()
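For context on the new `get_device_count()` helper and the `DEVICE_COUNT` constant above, here is a minimal usage sketch. It is illustrative only: `pick_device` is a hypothetical helper, not part of this PR, and it assumes an environment where `unsloth` imports successfully (a supported accelerator present).

```python
# Hypothetical consumer of the new module-level constants.
# Per the import-order check above, unsloth is imported before torch.
from unsloth import DEVICE_TYPE, DEVICE_COUNT
import torch

def pick_device(index: int = 0) -> torch.device:
    if DEVICE_COUNT == 0:
        # No CUDA/XPU accelerator was detected at import time.
        return torch.device("cpu")
    index = min(index, DEVICE_COUNT - 1)           # clamp to available devices
    return torch.device(f"{DEVICE_TYPE}:{index}")  # e.g. "cuda:1" or "xpu:0"
```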
37 changes: 19 additions & 18 deletions unsloth/kernels/utils.py
@@ -18,7 +18,7 @@
next_power_of_2 = triton.next_power_of_2
import functools
from typing import Optional
from unsloth import DEVICE_TYPE
from unsloth import DEVICE_TYPE, DEVICE_COUNT

# torch.cuda.amp.custom_fwd is deprecated >= 2.4
import torch
@@ -89,18 +89,19 @@ def get_ptr(x: Optional[torch.Tensor]):
get_ptr = bnb.functional.get_ptr


if DEVICE_TYPE == "cuda" and torch.cuda.device_count() > 1:
torch_gpu_device = torch.cuda.device
elif DEVICE_TYPE == "xpu" and torch.xpu.device_count() > 1:
torch_gpu_device = torch.xpu.device
if DEVICE_COUNT > 1:
if DEVICE_TYPE == "cuda":
torch_gpu_device = torch.cuda.device
elif DEVICE_TYPE == "xpu":
torch_gpu_device = torch.xpu.device
else:
from contextlib import nullcontext
def torch_gpu_device(device): return nullcontext()
pass

# INTEL GPU Specific Logic
if DEVICE_TYPE == "xpu":
_gpu_getCurrentRawStream = torch._C._xpu_getCurrentRawStream
# NVIDIA GPU Default Logic
else:
_gpu_getCurrentRawStream = torch._C._cuda_getCurrentRawStream
@@ -121,20 +122,20 @@ def _get_tensor_stream(tensor: torch_Tensor) -> c_void_p:
if DEVICE_TYPE == "xpu":
_XPU_STREAMS = {
(index := torch.xpu.device(i).idx) : ctypes.c_void_p(torch._C._xpu_getCurrentRawStream(index))
for i in range(torch.xpu.device_count())
for i in range(DEVICE_COUNT)
}
XPU_STREAMS = [None] * (max(_XPU_STREAMS.keys()) + 1)
WEIGHT_BUFFERS = [None] * (max(_XPU_STREAMS.keys()) + 1)
ABSMAX_BUFFERS = [None] * (max(_XPU_STREAMS.keys()) + 1)
for k, v in _XPU_STREAMS.items():
XPU_STREAMS[k] = v
XPU_STREAMS = tuple(XPU_STREAMS)
del _XPU_STREAMS
else:
# NVIDIA GPU Default Logic
_CUDA_STREAMS = {
(index := torch.cuda.device(i).idx) : ctypes.c_void_p(torch._C._cuda_getCurrentRawStream(index))
for i in range(torch.cuda.device_count())
for i in range(DEVICE_COUNT)
}
CUDA_STREAMS = [None] * (max(_CUDA_STREAMS.keys()) + 1)
WEIGHT_BUFFERS = [None] * (max(_CUDA_STREAMS.keys()) + 1)
@@ -152,16 +153,16 @@ def _get_tensor_stream(tensor: torch_Tensor) -> c_void_p:
# TODO: After adding XPU BNB support, this function should be implemented
def cdequantize_blockwise_fp32(*args, **kwargs):
raise RuntimeError("XPU BNB support is not implemented yet. cdequantize_blockwise_fp32 should not be called now.")

def cdequantize_blockwise_fp16_nf4(*args, **kwargs):
raise RuntimeError("XPU BNB support is not implemented yet. cdequantize_blockwise_fp16_nf4 should not be called now.")

def cdequantize_blockwise_bf16_nf4(*args, **kwargs):
raise RuntimeError("XPU BNB support is not implemented yet. cdequantize_blockwise_bf16_nf4 should not be called now.")

def cgemm_4bit_inference_naive_fp16(*args, **kwargs):
raise RuntimeError("XPU BNB support is not implemented yet. cgemm_4bit_inference_naive_fp16 should not be called now.")

def cgemm_4bit_inference_naive_bf16(*args, **kwargs):
raise RuntimeError("XPU BNB support is not implemented yet. cgemm_4bit_inference_naive_bf16 should not be called now.")
else:
@@ -193,7 +194,7 @@ def get_lora_parameters(proj):
adapter = getattr(proj, "active_adapters", None)
if adapter is None: adapter = getattr(proj, "active_adapter", ("default"))
adapter = adapter[0]

return (
W,
getattr(W, "quant_state", None),
@@ -232,7 +233,7 @@ def get_lora_parameters_bias(proj):
if DEVICE_TYPE == "xpu" and HAS_XPU_STREAM:
@torch.inference_mode
def fast_dequantize(W, quant_state = None, out = None, use_global_buffer = False):
# TODO: After adding XPU BNB support, check this function
if quant_state is None: return W
if type(quant_state) is not list:
# New quant_state as a class
@@ -535,7 +536,7 @@ def fast_gemv(X, W, quant_state, out = None):
device = W.device
device_index = device.index
CUDA_STREAM = CUDA_STREAMS[device_index]

# assert(dtype == X.dtype)
bout = shape[0]

@@ -669,7 +670,7 @@ def fast_linear_forward(proj, X, temp_lora = None, out = None):
lora_A._fast_lora = lora_A.to(dtype)
lora_B._fast_lora = lora_B.to(dtype)
pass

if bsz == 1:
out = out.view(out_dim)
temp_lora = torch_mv(lora_A._fast_lora, X.ravel(), out = temp_lora)
@@ -709,6 +710,6 @@ def matmul_lora(X, W, W_quant, A, B, s, out = None):
out.addmm_(XA, B.to(dtype), alpha = s)
# out += (X @ A.to(dtype)) @ (s * B.to(dtype))
pass

return out.view(batch, seq_len, -1) if reshape else out
pass
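Aside for reviewers: the `DEVICE_COUNT > 1` branch above keeps the real `torch.cuda.device` / `torch.xpu.device` context manager only when several devices exist, and swaps in a `nullcontext` no-op otherwise. A stand-alone sketch of that pattern (simplified, CUDA-only, not the library code itself):

```python
# Simplified reproduction of the single- vs multi-device dispatch above.
# On one device the guard is a no-op, so hot paths avoid the overhead of
# switching the current device before every kernel launch.
import torch
from contextlib import nullcontext

DEVICE_COUNT = torch.cuda.device_count()

if DEVICE_COUNT > 1:
    torch_gpu_device = torch.cuda.device      # real device guard
else:
    def torch_gpu_device(device):             # no-op stand-in
        return nullcontext()

def scale_on_device(x: torch.Tensor) -> torch.Tensor:
    # Everything inside the block runs with x's device as the current device.
    with torch_gpu_device(x.device.index if x.is_cuda else 0):
        return x * 2
```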
43 changes: 33 additions & 10 deletions unsloth/models/_utils.py
@@ -77,7 +77,7 @@
import re
import warnings, subprocess, re, inspect, psutil, os, math
from unsloth_zoo.utils import Version
from unsloth_zoo import DEVICE_TYPE
from unsloth import DEVICE_TYPE, DEVICE_COUNT

from unsloth_zoo.tokenizer_utils import (
patch_tokenizer as _patch_tokenizer,
@@ -142,12 +142,6 @@
import logging
logging.getLogger("transformers.tokenization_utils_base").setLevel(logging.CRITICAL+1)

def get_device_num():
if DEVICE_TYPE == "xpu":
return torch.xpu.device_count()
else:
return torch.cuda.device_count()

# Ignore logging messages
class HideLoggingMessage(logging.Filter):
__slots__ = "text",
@@ -746,8 +740,7 @@ def get_statistics():
pass
pass
try:
devices = get_device_num()
_get_statistics(f"{devices if devices <= 8 else 9}")
_get_statistics(f"{DEVICE_COUNT if DEVICE_COUNT <= 8 else 9}")
except:
pass
if disabled: enable_progress_bars()
@@ -773,14 +766,44 @@ def get_statistics():
)
exec(BitsAndBytesConfig__init__, globals())

if get_device_num() == 1:
if DEVICE_COUNT == 1:
from accelerate.utils.dataclasses import DistributedType
def _prepare_backend(self, *args, **kwargs): return None, DistributedType.NO
import accelerate.state
accelerate.state.PartialState._prepare_backend = _prepare_backend
accelerate.accelerator.Accelerator.distributed_type = lambda *args, **kwargs: DistributedType.NO
pass

# Helper to move multiple tensors to the same device
def move_to_device(target_device, *tensors):
"""
Move multiple tensors to target device if they're not already there.

Args:
target_device: The target device to move tensors to
*tensors: Variable number of tensors to potentially move

Returns:
tuple: The tensors on the target device (same objects if already on device, new if moved)
"""
if isinstance(target_device, int):
target_device = torch.device(target_device)
elif isinstance(target_device, str):
# if string we expect it to be a device name like "cuda:0"
target_device = torch.device(target_device)
elif isinstance(target_device, torch.device):
pass
else:
raise ValueError(f"Invalid target device: {target_device}")
pass
moved_tensors = []
for tensor in tensors:
if tensor.device != target_device:
moved_tensors.append(tensor.to(target_device))
else:
moved_tensors.append(tensor)
return tuple(moved_tensors) if len(moved_tensors) > 1 else moved_tensors[0]

import transformers.utils.quantization_config
transformers.utils.quantization_config.BitsAndBytesConfig.__init__ = _BitsAndBytesConfig__init__
# =============================================
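A short usage note for the new `move_to_device` helper above. This sketch assumes this branch of `unsloth/models/_utils.py` is installed and that `unsloth` can be imported (a supported accelerator present); the tensors are arbitrary examples.

```python
# move_to_device accepts an int index, a "cuda:0"-style string, or a
# torch.device, and only copies tensors that are not already on the target.
import torch
from unsloth.models._utils import move_to_device

h = torch.randn(1, 8, 16)         # stand-in for hidden_states
p = torch.arange(8).unsqueeze(0)  # stand-in for position_ids

h, p = move_to_device("cpu", h, p)   # both already on CPU -> returned unchanged
single = move_to_device("cpu", h)    # one input -> a tensor, not a 1-tuple
if torch.cuda.is_available():
    h, p = move_to_device(0, h, p)   # int index is wrapped via torch.device(0)
```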
23 changes: 14 additions & 9 deletions unsloth/models/cohere.py
@@ -78,7 +78,7 @@ def CohereAttention_fast_forward(
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
*args, **kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:

# Clear inference
if hasattr(self, "paged_attention"):
del self.paged_attention_K
@@ -254,6 +254,7 @@ def CohereAttention_fast_forward_inference(
do_prefill = False,
attention_mask = None,
):

Xn = hidden_states
bsz, _, hd = hidden_states.size()
K1, V1 = past_key_value
@@ -281,14 +282,14 @@ def CohereAttention_fast_forward_inference(
self.temp_QA = torch.empty((2, bsz, 1, attention_size), dtype = dtype, device = "cuda:0")
self.temp_KV = torch.empty((2, bsz, 1, n_kv_heads*head_dim), dtype = dtype, device = "cuda:0")
self.RH_Q = torch.empty((bsz, n_heads, 1, head_dim), dtype = dtype, device = "cuda:0")

# Mistral Nemo 12b has weird dimensions
if attention_size != hidden_size:
self.temp_O = torch.empty((1, bsz, hidden_size), dtype = dtype, device = "cuda:0")
else:
self.temp_O = self.temp_QA[1][:,:,:hidden_size]
pass

self.attention = torch.empty((bsz, n_heads, 1, KV_CACHE_INCREMENT+seq_len), dtype = dtype, device = "cuda:0")
self.scalar = 1.0 / math_sqrt(self.head_dim)
self.half_head_dim = head_dim // 2
@@ -320,7 +321,7 @@ def CohereAttention_fast_forward_inference(

# cos, sin = self.rotary_emb(Vn, seq_len = kv_seq_len)
# Qn, Kn = inplace_rope_embedding(Qn, Kn, cos, sin, position_ids)
cos, sin = self.rotary_emb.get_cached(kv_seq_len)
cos, sin = self.rotary_emb.get_cached(kv_seq_len, Qn.device.index)
cos = cos[position_ids].unsqueeze(1)
sin = sin[position_ids].unsqueeze(1)
h = self.half_head_dim
@@ -338,7 +339,7 @@ def CohereAttention_fast_forward_inference(
torch.neg(RH_K[:,:,:,:h], out = RH_K[:,:,:,:h])
Kn *= cos
Kn.addcmul_(RH_K, sin)

# New KV cache
# Kn = torch.cat([K1, Kn], dim = 2)
# Vn = torch.cat([V1, Vn], dim = 2)
@@ -397,7 +398,7 @@ def CohereModel_fast_forward_inference(
position_ids,
attention_mask = None,
):
out_weight = torch.empty_like(self.model.layers[0].input_layernorm.weight, dtype = torch.float32, device = "cuda:0")
out_weights = tuple(torch.empty_like(self.model.layers[0].input_layernorm.weight, dtype = torch.float32, device = torch.device(x)) for x in range(DEVICE_COUNT))
input_ids = input_ids[:,:self.max_seq_length]
hidden_states = self.model.embed_tokens(input_ids)
hidden_states = hidden_states.to(self.config.torch_dtype)
@@ -417,8 +418,12 @@ def CohereModel_fast_forward_inference(

next_decoder_cache = []
for idx, decoder_layer in enumerate(self.model.layers):
device_index = getattr(decoder_layer, "_per_layer_device_index", 0)
hidden_states, position_ids = move_to_device(
device_index, hidden_states, position_ids
)
residual = hidden_states
hidden_states = fast_layernorm_inference(decoder_layer.input_layernorm, hidden_states, out_weight)
hidden_states = fast_layernorm_inference(decoder_layer.input_layernorm, hidden_states, out_weights[device_index])
hidden_states_attention, present_key_value = CohereAttention_fast_forward_inference(
decoder_layer.self_attn,
hidden_states = hidden_states,
@@ -435,7 +440,7 @@ def CohereModel_fast_forward_inference(

next_decoder_cache.append(present_key_value)
pass
hidden_states = fast_layernorm_inference(self.model.norm, hidden_states, out_weight)
hidden_states = fast_layernorm_inference(self.model.norm, hidden_states, out_weights[device_index])

return BaseModelOutputWithPast(
last_hidden_state = hidden_states,
@@ -468,7 +473,7 @@ def pre_patch():
CohereForCausalLM .forward = CausalLM_fast_forward(CohereModel_fast_forward_inference)
PeftModelForCausalLM .forward = PeftModel_fast_forward
fix_prepare_inputs_for_generation(CohereForCausalLM)

import transformers.models.cohere.modeling_cohere
transformers.models.cohere.modeling_cohere.CohereRotaryEmbedding = LlamaRotaryEmbedding
return
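Reviewer note on the `out_weights` change above: instead of a single `cuda:0` scratch buffer, the inference path now pre-allocates one float32 layernorm scratch buffer per visible device and picks the buffer matching each layer's shard. A simplified sketch of that pattern; names like `layernorm_scratch` and the dummy layer objects are illustrative, not part of the PR.

```python
# One reusable float32 buffer per device; each decoder layer then uses the
# buffer that lives on the device it was placed on.
import torch

if torch.cuda.is_available():
    devices = [torch.device("cuda", i) for i in range(torch.cuda.device_count())]
else:
    devices = [torch.device("cpu")]

hidden_size = 4096
out_weights = tuple(
    torch.empty(hidden_size, dtype=torch.float32, device=d) for d in devices
)

def layernorm_scratch(decoder_layer):
    # The real code reads the shard index from `_per_layer_device_index`
    # (falling back to 0), exactly like the getattr call in the diff above.
    device_index = getattr(decoder_layer, "_per_layer_device_index", 0)
    return out_weights[device_index]
```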
17 changes: 9 additions & 8 deletions unsloth/models/falcon_h1.py
@@ -69,7 +69,7 @@ def FalconH1Attention_fast_forward(
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
*args, **kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:

# Clear inference
if hasattr(self, "paged_attention"):
del self.paged_attention_K
@@ -110,12 +110,13 @@ def FalconH1Attention_fast_forward(
# Extend RoPE dynamically to fit in VRAM
rotary_emb = self.rotary_emb
rotary_emb.extend_rope_embedding(V, seq_len = kv_seq_len)
device_index = Q.device.index

if position_ids is None:
# Useful for LongRoPE
cos, sin = rotary_emb.get_cached(kv_seq_len)
cos, sin = rotary_emb.get_cached(kv_seq_len, device_index)
else:
cos, sin = rotary_emb(V, seq_len = kv_seq_len)
cos, sin = rotary_emb.get_cached(kv_seq_len, device_index)
Q, K = fast_rope_embedding(Q, K, cos, sin)

if past_key_value is not None:
@@ -245,14 +246,14 @@ def FalconH1Attention_fast_forward_inference(
self.temp_QA = torch.empty((2, bsz, 1, attention_size), dtype = dtype, device = device)
self.temp_KV = torch.empty((2, bsz, 1, n_kv_heads*head_dim), dtype = dtype, device = device)
self.RH_Q = torch.empty((bsz, n_heads, 1, head_dim), dtype = dtype, device = device)

# Mistral Nemo 12b has weird dimensions
if attention_size != hidden_size:
self.temp_O = torch.empty((1, bsz, hidden_size), dtype = dtype, device = device)
else:
self.temp_O = self.temp_QA[1][:,:,:hidden_size]
pass

self.attention = torch.empty((bsz, n_heads, 1, KV_CACHE_INCREMENT+seq_len), dtype = dtype, device = device)
self.scalar = 1.0 / math_sqrt(self.head_dim)
self.half_head_dim = head_dim // 2
@@ -280,7 +281,7 @@ def FalconH1Attention_fast_forward_inference(
# Need to do it prior 2 steps before hitting full on short KV cache
# or else error
self.rotary_emb.extend_rope_embedding(Vn, seq_len + 2)
cos, sin = self.rotary_emb.get_cached(kv_seq_len)
cos, sin = self.rotary_emb.get_cached(kv_seq_len, Qn.device.index)
cos = cos[position_ids].unsqueeze(1)
sin = sin[position_ids].unsqueeze(1)
h = self.half_head_dim
@@ -298,7 +299,7 @@ def FalconH1Attention_fast_forward_inference(
RH_K[:,:,:,:h].neg_() #torch.neg(RH_K[:,:,:,:h], out = RH_K[:,:,:,:h])
Kn *= cos
Kn.addcmul_(RH_K, sin)

# New KV cache
# Kn = torch.cat([K1, Kn], dim = 2)
# Vn = torch.cat([V1, Vn], dim = 2)
@@ -580,7 +581,7 @@ def _fast_prepare_inputs_for_generation(
**kwargs,):
# Overwritten -- has a unique cache type, `FalconHybridMambaAttentionDynamicCache`
empty_past_kv = past_key_values is None

# If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
# Exception 1: when passing input_embeds, input_ids may be missing entries
# Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here
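Background for the `get_cached(kv_seq_len, device_index)` call sites in this file and in cohere.py: when layers are sharded across several GPUs, the cached cos/sin tables must live on the same device as the Q/K tensors they rotate. A toy per-device RoPE cache illustrating the idea; this is not Unsloth's rotary-embedding class, and shapes and names here are assumptions.

```python
import torch

class ToyRotaryCache:
    # Precomputes cos/sin once per device so lookups never cross devices.
    def __init__(self, dim=128, base=10000.0, max_seq_len=4096, device_count=1):
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        freqs = torch.outer(torch.arange(max_seq_len).float(), inv_freq)
        emb = torch.cat((freqs, freqs), dim=-1)
        cuda = torch.cuda.is_available()
        self._cache = [
            (emb.cos().to(f"cuda:{i}") if cuda else emb.cos(),
             emb.sin().to(f"cuda:{i}") if cuda else emb.sin())
            for i in range(device_count)
        ]

    def get_cached(self, seq_len, device_index=0):
        cos, sin = self._cache[device_index]
        return cos[:seq_len], sin[:seq_len]

# Mirrors the pattern of the patched call sites, e.g.
# cos, sin = rope.get_cached(kv_seq_len, Qn.device.index)
```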