Merged
Changes from 12 commits
15 changes: 13 additions & 2 deletions unsloth/__init__.py
@@ -22,7 +22,7 @@
already_imported = [mod for mod in critical_modules if mod in sys.modules]

# This check is critical because Unsloth optimizes these libraries by modifying
# their code at import time. If they're imported first, the original (slower,
# their code at import time. If they're imported first, the original (slower,
# more memory-intensive) implementations will be used instead of Unsloth's
# optimized versions, potentially causing OOM errors or slower training.

@@ -73,6 +73,17 @@ def get_device_type():
pass
DEVICE_TYPE : str = get_device_type()

def get_device_count():
if DEVICE_TYPE == "cuda":
return torch.cuda.device_count()
elif DEVICE_TYPE == "xpu":
return torch.xpu.device_count()
else:
return 0
Contributor suggested change:
        return 0
        return 1

pass

DEVICE_COUNT : int = get_device_count()

# Reduce VRAM usage by reducing fragmentation
# And optimize pinning of memory
if DEVICE_TYPE == "cuda" and os.environ.get("UNSLOTH_VLLM_STANDBY", "0")=="0":
@@ -237,4 +248,4 @@ def is_bf16_supported(): return SUPPORTS_BFLOAT16
from .trainer import *

# Patch TRL trainers for backwards compatibility
_patch_trl_trainer()
_patch_trl_trainer()
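
The net effect of the __init__.py hunks is a new module-level DEVICE_COUNT computed right after DEVICE_TYPE. Below is a minimal standalone sketch of that dispatch; the CPU branch of get_device_type and the final print are illustrative additions, not part of the PR, and the fallback return value is exactly what the reviewer's suggestion above is about.

import torch

def get_device_type() -> str:
    # Prefer NVIDIA CUDA, then Intel XPU; fall back to CPU for this sketch
    # (the real module raises if no supported accelerator is found).
    if torch.cuda.is_available():
        return "cuda"
    if hasattr(torch, "xpu") and torch.xpu.is_available():
        return "xpu"
    return "cpu"

DEVICE_TYPE: str = get_device_type()

def get_device_count() -> int:
    if DEVICE_TYPE == "cuda":
        return torch.cuda.device_count()
    elif DEVICE_TYPE == "xpu":
        return torch.xpu.device_count()
    # The PR returns 0 on unrecognized devices; the review suggestion prefers 1
    # so that per-device loops still execute once on machines without a GPU.
    return 0

DEVICE_COUNT: int = get_device_count()
print(DEVICE_TYPE, DEVICE_COUNT)
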
22 changes: 11 additions & 11 deletions unsloth/kernels/utils.py
@@ -100,7 +100,7 @@ def torch_gpu_device(device): return nullcontext()

# INTEL GPU Specific Logic
if DEVICE_TYPE == "xpu":
_gpu_getCurrentRawStream = torch._C._xpu_getCurrentRawStream
_gpu_getCurrentRawStream = torch._C._xpu_getCurrentRawStream
# NVIDIA GPU Default Logic
else:
_gpu_getCurrentRawStream = torch._C._cuda_getCurrentRawStream
@@ -126,7 +126,7 @@ def _get_tensor_stream(tensor: torch_Tensor) -> c_void_p:
XPU_STREAMS = [None] * (max(_XPU_STREAMS.keys()) + 1)
WEIGHT_BUFFERS = [None] * (max(_XPU_STREAMS.keys()) + 1)
ABSMAX_BUFFERS = [None] * (max(_XPU_STREAMS.keys()) + 1)
for k, v in _XPU_STREAMS.items():
for k, v in _XPU_STREAMS.items():
XPU_STREAMS[k] = v
XPU_STREAMS = tuple(XPU_STREAMS)
del _XPU_STREAMS
@@ -152,16 +152,16 @@ def _get_tensor_stream(tensor: torch_Tensor) -> c_void_p:
# TODO: After adding XPU BNB support, this function should be implemented
def cdequantize_blockwise_fp32(*args, **kwargs):
raise RuntimeError("XPU BNB support is not implemented yet. cdequantize_blockwise_fp32 should not be called now.")

def cdequantize_blockwise_fp16_nf4(*args, **kwargs):
raise RuntimeError("XPU BNB support is not implemented yet. cdequantize_blockwise_fp16_nf4 should not be called now.")

def cdequantize_blockwise_bf16_nf4(*args, **kwargs):
raise RuntimeError("XPU BNB support is not implemented yet. cdequantize_blockwise_bf16_nf4 should not be called now.")

def cgemm_4bit_inference_naive_fp16(*args, **kwargs):
raise RuntimeError("XPU BNB support is not implemented yet. cgemm_4bit_inference_naive_fp16 should not be called now.")

def cgemm_4bit_inference_naive_bf16(*args, **kwargs):
raise RuntimeError("XPU BNB support is not implemented yet. cgemm_4bit_inference_naive_bf16 should not be called now.")
else:
@@ -193,7 +193,7 @@ def get_lora_parameters(proj):
adapter = getattr(proj, "active_adapters", None)
if adapter is None: adapter = getattr(proj, "active_adapter", ("default"))
adapter = adapter[0]

return (
W,
getattr(W, "quant_state", None),
@@ -232,7 +232,7 @@ def get_lora_parameters_bias(proj):
if DEVICE_TYPE == "xpu" and HAS_XPU_STREAM:
@torch.inference_mode
def fast_dequantize(W, quant_state = None, out = None, use_global_buffer = False):
# TODO: After adding XPU BNB support, check this function
# TODO: After adding XPU BNB support, check this function
if quant_state is None: return W
if type(quant_state) is not list:
# New quant_state as a class
@@ -535,7 +535,7 @@ def fast_gemv(X, W, quant_state, out = None):
device = W.device
device_index = device.index
CUDA_STREAM = CUDA_STREAMS[device_index]

# assert(dtype == X.dtype)
bout = shape[0]

@@ -669,7 +669,7 @@ def fast_linear_forward(proj, X, temp_lora = None, out = None):
lora_A._fast_lora = lora_A.to(dtype)
lora_B._fast_lora = lora_B.to(dtype)
pass

if bsz == 1:
out = out.view(out_dim)
temp_lora = torch_mv(lora_A._fast_lora, X.ravel(), out = temp_lora)
@@ -709,6 +709,6 @@ def matmul_lora(X, W, W_quant, A, B, s, out = None):
out.addmm_(XA, B.to(dtype), alpha = s)
# out += (X @ A.to(dtype)) @ (s * B.to(dtype))
pass

return out.view(batch, seq_len, -1) if reshape else out
pass
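
Most of the kernels/utils.py hunk is whitespace cleanup around the XPU branches; the substantive pattern being touched is the per-device stream table, where a dict keyed by device index is flattened into a tuple so hot-path lookups become a plain integer index. A rough CUDA-flavoured sketch of that pattern follows; the names _RAW_STREAMS, STREAMS, and get_raw_stream are mine, not unsloth's.

import torch

# Collect the raw stream pointer for every visible GPU, keyed by device index.
_RAW_STREAMS = {
    i: torch.cuda.current_stream(i).cuda_stream
    for i in range(torch.cuda.device_count())
}

# Flatten the dict into a tuple so lookups avoid a dict hash on every call.
if _RAW_STREAMS:
    STREAMS = [None] * (max(_RAW_STREAMS.keys()) + 1)
    for k, v in _RAW_STREAMS.items():
        STREAMS[k] = v
    STREAMS = tuple(STREAMS)
else:
    STREAMS = ()
del _RAW_STREAMS

def get_raw_stream(tensor: torch.Tensor) -> int:
    # Mirrors the role of _get_tensor_stream: resolve the stream of the
    # tensor's own device by index.
    return STREAMS[tensor.device.index]

if torch.cuda.is_available():
    x = torch.empty(1, device = "cuda:0")
    print(hex(get_raw_stream(x)))
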
20 changes: 20 additions & 0 deletions unsloth/models/_utils.py
@@ -781,6 +781,26 @@ def _prepare_backend(self, *args, **kwargs): return None, DistributedType.NO
accelerate.accelerator.Accelerator.distributed_type = lambda *args, **kwargs: DistributedType.NO
pass

# to move multiple tensors to the same device
def move_to_device(target_device, *tensors):
"""
Move multiple tensors to target device if they're not already there.

Args:
target_device: The target device to move tensors to
*tensors: Variable number of tensors to potentially move

Returns:
tuple: The tensors on the target device (same objects if already on device, new if moved)
"""
moved_tensors = []
for tensor in tensors:
if tensor.device != target_device:
moved_tensors.append(tensor.to(target_device))
else:
moved_tensors.append(tensor)
return tuple(moved_tensors) if len(moved_tensors) > 1 else moved_tensors[0]

import transformers.utils.quantization_config
transformers.utils.quantization_config.BitsAndBytesConfig.__init__ = _BitsAndBytesConfig__init__
# =============================================
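
For clarity, here is a self-contained usage sketch of the move_to_device helper added above; the helper body is condensed from the diff and the example tensors and device are made up.

import torch

def move_to_device(target_device, *tensors):
    # Condensed version of the helper above: move only what is not already there.
    moved = [t.to(target_device) if t.device != target_device else t for t in tensors]
    return tuple(moved) if len(moved) > 1 else moved[0]

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
hidden_states = torch.randn(2, 8, 16)
position_ids  = torch.arange(8).unsqueeze(0)

# Tensors already on `device` are returned unchanged; the rest are copied over.
hidden_states, position_ids = move_to_device(device, hidden_states, position_ids)
print(hidden_states.device, position_ids.device)
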
16 changes: 10 additions & 6 deletions unsloth/models/cohere.py
@@ -78,7 +78,7 @@ def CohereAttention_fast_forward(
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
*args, **kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:

# Clear inference
if hasattr(self, "paged_attention"):
del self.paged_attention_K
@@ -254,6 +254,7 @@ def CohereAttention_fast_forward_inference(
do_prefill = False,
attention_mask = None,
):

Xn = hidden_states
bsz, _, hd = hidden_states.size()
K1, V1 = past_key_value
@@ -281,14 +282,14 @@ def CohereAttention_fast_forward_inference(
self.temp_QA = torch.empty((2, bsz, 1, attention_size), dtype = dtype, device = "cuda:0")
self.temp_KV = torch.empty((2, bsz, 1, n_kv_heads*head_dim), dtype = dtype, device = "cuda:0")
self.RH_Q = torch.empty((bsz, n_heads, 1, head_dim), dtype = dtype, device = "cuda:0")

# Mistral Nemo 12b has weird dimensions
if attention_size != hidden_size:
self.temp_O = torch.empty((1, bsz, hidden_size), dtype = dtype, device = "cuda:0")
else:
self.temp_O = self.temp_QA[1][:,:,:hidden_size]
pass

self.attention = torch.empty((bsz, n_heads, 1, KV_CACHE_INCREMENT+seq_len), dtype = dtype, device = "cuda:0")
self.scalar = 1.0 / math_sqrt(self.head_dim)
self.half_head_dim = head_dim // 2
@@ -320,7 +321,7 @@ def CohereAttention_fast_forward_inference(

# cos, sin = self.rotary_emb(Vn, seq_len = kv_seq_len)
# Qn, Kn = inplace_rope_embedding(Qn, Kn, cos, sin, position_ids)
cos, sin = self.rotary_emb.get_cached(kv_seq_len)
cos, sin = self.rotary_emb.get_cached(kv_seq_len, device = Qn.device)
cos = cos[position_ids].unsqueeze(1)
sin = sin[position_ids].unsqueeze(1)
h = self.half_head_dim
@@ -338,7 +339,7 @@ def CohereAttention_fast_forward_inference(
torch.neg(RH_K[:,:,:,:h], out = RH_K[:,:,:,:h])
Kn *= cos
Kn.addcmul_(RH_K, sin)

# New KV cache
# Kn = torch.cat([K1, Kn], dim = 2)
# Vn = torch.cat([V1, Vn], dim = 2)
@@ -417,6 +418,9 @@ def CohereModel_fast_forward_inference(

next_decoder_cache = []
for idx, decoder_layer in enumerate(self.model.layers):
hidden_states, out_weight, position_ids = move_to_device(
decoder_layer._per_layer_device, hidden_states, out_weight, position_ids
)
residual = hidden_states
hidden_states = fast_layernorm_inference(decoder_layer.input_layernorm, hidden_states, out_weight)
hidden_states_attention, present_key_value = CohereAttention_fast_forward_inference(
@@ -468,7 +472,7 @@ def pre_patch():
CohereForCausalLM .forward = CausalLM_fast_forward(CohereModel_fast_forward_inference)
PeftModelForCausalLM .forward = PeftModel_fast_forward
fix_prepare_inputs_for_generation(CohereForCausalLM)

import transformers.models.cohere.modeling_cohere
transformers.models.cohere.modeling_cohere.CohereRotaryEmbedding = LlamaRotaryEmbedding
return
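
The main behavioural change in cohere.py is the per-layer device hop at the top of the decode loop. Below is a hedged, standalone sketch of that pattern; TinyLayer, its layer math, and the way _per_layer_device gets set are stand-ins for unsloth's fused decoder layers, which the diff only reads that attribute from.

import torch
import torch.nn as nn

def move_to_device(target_device, *tensors):
    return tuple(t.to(target_device) if t.device != target_device else t for t in tensors)

class TinyLayer(nn.Module):
    def __init__(self, dim, device):
        super().__init__()
        self.proj = nn.Linear(dim, dim, device = device)
        # The diff reads a `_per_layer_device` attribute off each decoder layer;
        # here we simply set it by hand to the device the weights live on.
        self._per_layer_device = torch.device(device)

    def forward(self, x):
        return x + self.proj(x)

devices = ["cuda:0", "cuda:1"] if torch.cuda.device_count() > 1 else ["cpu", "cpu"]
layers = nn.ModuleList(TinyLayer(32, d) for d in devices)

hidden_states = torch.randn(1, 4, 32, device = devices[0])
position_ids  = torch.arange(4, device = devices[0]).unsqueeze(0)

for layer in layers:
    # As in the diff: rolling activations follow the layer's device before it runs.
    hidden_states, position_ids = move_to_device(layer._per_layer_device, hidden_states, position_ids)
    hidden_states = layer(hidden_states)

print(hidden_states.device)
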
6 changes: 3 additions & 3 deletions unsloth/models/falcon_h1.py
@@ -113,9 +113,9 @@ def FalconH1Attention_fast_forward(

if position_ids is None:
# Useful for LongRoPE
cos, sin = rotary_emb.get_cached(kv_seq_len)
cos, sin = rotary_emb.get_cached(kv_seq_len, device=Q.device)
else:
cos, sin = rotary_emb(V, seq_len = kv_seq_len)
cos, sin = rotary_emb.get_cached(seq_len = kv_seq_len, device=Q.device)
Q, K = fast_rope_embedding(Q, K, cos, sin)

if past_key_value is not None:
@@ -280,7 +280,7 @@ def FalconH1Attention_fast_forward_inference(
# Need to do it prior 2 steps before hitting full on short KV cache
# or else error
self.rotary_emb.extend_rope_embedding(Vn, seq_len + 2)
cos, sin = self.rotary_emb.get_cached(kv_seq_len)
cos, sin = self.rotary_emb.get_cached(kv_seq_len, device = Qn.device)
cos = cos[position_ids].unsqueeze(1)
sin = sin[position_ids].unsqueeze(1)
h = self.half_head_dim
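
The Falcon-H1 hunks only change call sites so that get_cached receives the query's device. For context, a plain reference version of the rotation those cached cos/sin tables feed into is sketched below; unsloth's fast_rope_embedding is a fused Triton kernel, so this is just the unfused rotate-half math, with all shapes chosen for illustration.

import torch

def rotate_half(x):
    # Standard rotate-half used by Llama-style RoPE.
    h = x.shape[-1] // 2
    return torch.cat((-x[..., h:], x[..., :h]), dim = -1)

def apply_rope(q, k, cos, sin):
    # cos/sin: [seq_len, head_dim], built on the same device as q/k, which is
    # exactly what the new device= argument to get_cached guarantees.
    cos = cos[None, None, :, :]
    sin = sin[None, None, :, :]
    return q * cos + rotate_half(q) * sin, k * cos + rotate_half(k) * sin

bsz, n_heads, n_kv_heads, seq_len, head_dim = 1, 8, 2, 16, 64
q = torch.randn(bsz, n_heads,    seq_len, head_dim)
k = torch.randn(bsz, n_kv_heads, seq_len, head_dim)

t = torch.arange(seq_len, dtype = torch.float32)
inv_freq = 1.0 / (10000.0 ** (torch.arange(0, head_dim, 2, dtype = torch.float32) / head_dim))
radians = torch.outer(t, inv_freq)
emb = torch.cat((radians, radians), dim = -1)   # [seq_len, head_dim]

q, k = apply_rope(q, k, emb.cos(), emb.sin())
print(q.shape, k.shape)
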
50 changes: 34 additions & 16 deletions unsloth/models/gemma.py
@@ -170,6 +170,10 @@ def GemmaModel_fast_forward_inference(

next_decoder_cache = []
for idx, decoder_layer in enumerate(self.model.layers):
hidden_states, out_weight, position_ids = move_to_device(
decoder_layer._per_layer_device, hidden_states, out_weight, position_ids
)

residual = hidden_states
hidden_states = fast_rms_layernorm_inference_gemma(decoder_layer.input_layernorm, hidden_states, out_weight)
hidden_states, present_key_value = LlamaAttention_fast_forward_inference(
@@ -224,9 +228,16 @@ def __init__(self, dim = None, max_position_embeddings=2048, base=10000, device=
self.base = base
# Dynamic RoPE we first set it to a max of 4 * 8192 tokens then we iteratively grow this
self.current_rope_size = min(4 * 8192, self.max_position_embeddings)
self.multi_gpu_cos_cached = [None]*DEVICE_COUNT
self.multi_gpu_sin_cached = [None]*DEVICE_COUNT

# Build here to make `torch.jit.trace` work.
self._set_cos_sin_cache(seq_len=self.current_rope_size, device=device, dtype=torch.get_default_dtype())
for device in range(DEVICE_COUNT):
self._set_cos_sin_cache(seq_len=self.current_rope_size, device=torch.device(device), dtype=torch.get_default_dtype())

# dummy so that patch_utils doesn't fail for now
self.cos_cached = torch.empty(1, device=torch.cuda.current_device(), dtype=torch.get_default_dtype())
self.sin_cached = torch.empty(1, device=torch.cuda.current_device(), dtype=torch.get_default_dtype())
pass

def _set_cos_sin_cache(self, seq_len, device, dtype):
@@ -245,32 +256,38 @@ def _set_cos_sin_cache(self, seq_len, device, dtype):

emb = torch.cat((radians_new, radians_new), dim = -1)
# We must do RoPE in float32!
cos = emb.cos().to(device = "cuda", non_blocking = True)#, dtype = dtype)
sin = emb.sin().to(device = "cuda", non_blocking = True)#, dtype = dtype)
self.register_buffer("cos_cached", cos, persistent = False)
self.register_buffer("sin_cached", sin, persistent = False)
cos = emb.cos().to(device = device, non_blocking = True)#, dtype = dtype)
sin = emb.sin().to(device = device, non_blocking = True)#, dtype = dtype)
self.multi_gpu_cos_cached[device.index] = cos
self.multi_gpu_sin_cached[device.index] = sin
return cos, sin
pass

def forward(self, x, position_ids=None, seq_len=None):
# x: [bs, num_attention_heads, seq_len, head_size]
if seq_len > self.current_rope_size:
if seq_len is not None and seq_len > self.current_rope_size:
self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)

device_index = x.device.index

return (
self.cos_cached[:seq_len].to(dtype=x.dtype),
self.sin_cached[:seq_len].to(dtype=x.dtype),
self.multi_gpu_cos_cached[device_index][:seq_len].to(dtype=x.dtype),
self.multi_gpu_sin_cached[device_index][:seq_len].to(dtype=x.dtype),
)
pass

def get_cached(self, seq_len = None):
return self.cos_cached, self.sin_cached
def get_cached(self, seq_len = None, device = None):
if device is None:
device = torch.cuda.current_device()
return self.multi_gpu_cos_cached[device.index], self.multi_gpu_sin_cached[device.index]
pass

def extend_rope_embedding(self, x, seq_len):
if seq_len <= self.current_rope_size: return
# Iteratively grow by increments of 8192
self.current_rope_size = math.ceil(seq_len / 8192) * 8192
self._set_cos_sin_cache(self.current_rope_size, device = "cuda", dtype = x.dtype)
for device in range(DEVICE_COUNT):
self._set_cos_sin_cache(self.current_rope_size, device = torch.device(device), dtype = x.dtype)
pass
pass

@@ -288,7 +305,7 @@ def __init__(self, dim = None, max_position_embeddings=2048, base=10000, device=
pass

def _set_cos_sin_cache(self, seq_len, device, dtype):
# Note: on the original Llama codebase, these tensors are created on the target device (and not on CPU) and
# Note: on the original Llama codebase, these tensors are created on the target device (and not on CPU) and
# in FP32. They are applied (multiplied) in FP32 as well.
self.current_rope_size = seq_len

@@ -304,10 +321,11 @@ def _set_cos_sin_cache(self, seq_len, device, dtype):

emb = torch.cat((radians_new, radians_new), dim = -1)
# We must do RoPE in float32!
cos = emb.cos().to(device = "cuda", non_blocking = True)#, dtype = dtype)
sin = emb.sin().to(device = "cuda", non_blocking = True)#, dtype = dtype)
self.register_buffer("cos_cached", cos, persistent = False)
self.register_buffer("sin_cached", sin, persistent = False)
cos = emb.cos().to(device = device, non_blocking = True)#, dtype = dtype)
sin = emb.sin().to(device = device, non_blocking = True)#, dtype = dtype)
self.multi_gpu_cos_cached[device.index] = cos
self.multi_gpu_sin_cached[device.index] = sin
return cos, sin
pass
pass

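
Putting the gemma.py hunks together: the single cos_cached/sin_cached buffers become per-device lists indexed by device.index, filled for every GPU at construction and regrown in 8192-token steps. A condensed, hedged sketch of that cache layout is below; the class name, default sizes, and the CPU fallback are mine, not unsloth's.

import math
import torch

DEVICE_COUNT = max(torch.cuda.device_count(), 1)

class PerDeviceRotaryCache:
    def __init__(self, dim = 64, base = 10000.0, initial_len = 4 * 8192):
        self.dim, self.base = dim, base
        self.current_rope_size = initial_len
        self.multi_gpu_cos_cached = [None] * DEVICE_COUNT
        self.multi_gpu_sin_cached = [None] * DEVICE_COUNT
        for idx in range(DEVICE_COUNT):
            device = torch.device(f"cuda:{idx}") if torch.cuda.is_available() else torch.device("cpu")
            self._set_cos_sin_cache(initial_len, device)

    def _set_cos_sin_cache(self, seq_len, device):
        # RoPE tables are built in float32 directly on the target device.
        self.current_rope_size = seq_len
        inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype = torch.float32, device = device) / self.dim))
        t = torch.arange(seq_len, dtype = torch.float32, device = device)
        radians = torch.outer(t, inv_freq)
        emb = torch.cat((radians, radians), dim = -1)
        idx = device.index if device.index is not None else 0
        self.multi_gpu_cos_cached[idx] = emb.cos()
        self.multi_gpu_sin_cached[idx] = emb.sin()

    def get_cached(self, seq_len = None, device = None):
        # Mirrors the new signature: return the table that already lives on `device`.
        idx = 0 if device is None or device.index is None else device.index
        return self.multi_gpu_cos_cached[idx], self.multi_gpu_sin_cached[idx]

    def extend_rope_embedding(self, seq_len):
        if seq_len <= self.current_rope_size:
            return
        # Grow every per-device table in 8192-token increments, as in the diff.
        new_len = math.ceil(seq_len / 8192) * 8192
        for idx in range(DEVICE_COUNT):
            self._set_cos_sin_cache(new_len, self.multi_gpu_cos_cached[idx].device)

cache = PerDeviceRotaryCache()
x = torch.empty(1, 8, 16, 64, device = cache.multi_gpu_cos_cached[0].device)
cos, sin = cache.get_cached(seq_len = 16, device = x.device)
print(cos.shape, cos.device)
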