diff --git a/unsloth/__init__.py b/unsloth/__init__.py index a1a39fd19..4da08da13 100644 --- a/unsloth/__init__.py +++ b/unsloth/__init__.py @@ -22,7 +22,7 @@ already_imported = [mod for mod in critical_modules if mod in sys.modules] # This check is critical because Unsloth optimizes these libraries by modifying -# their code at import time. If they're imported first, the original (slower, +# their code at import time. If they're imported first, the original (slower, # more memory-intensive) implementations will be used instead of Unsloth's # optimized versions, potentially causing OOM errors or slower training. @@ -73,6 +73,17 @@ def get_device_type(): pass DEVICE_TYPE : str = get_device_type() +def get_device_count(): + if DEVICE_TYPE == "cuda": + return torch.cuda.device_count() + elif DEVICE_TYPE == "xpu": + return torch.xpu.device_count() + else: + return 0 +pass + +DEVICE_COUNT : int = get_device_count() + # Reduce VRAM usage by reducing fragmentation # And optimize pinning of memory if DEVICE_TYPE == "cuda" and os.environ.get("UNSLOTH_VLLM_STANDBY", "0")=="0": @@ -237,4 +248,4 @@ def is_bf16_supported(): return SUPPORTS_BFLOAT16 from .trainer import * # Patch TRL trainers for backwards compatibility -_patch_trl_trainer() \ No newline at end of file +_patch_trl_trainer() diff --git a/unsloth/kernels/utils.py b/unsloth/kernels/utils.py index 1c65246b3..645319d42 100644 --- a/unsloth/kernels/utils.py +++ b/unsloth/kernels/utils.py @@ -18,7 +18,7 @@ next_power_of_2 = triton.next_power_of_2 import functools from typing import Optional -from unsloth import DEVICE_TYPE +from unsloth import DEVICE_TYPE, DEVICE_COUNT # torch.cuda.amp.custom_fwd is deprecated >= 2.4 import torch @@ -89,10 +89,11 @@ def get_ptr(x: Optional[torch.Tensor]): get_ptr = bnb.functional.get_ptr -if DEVICE_TYPE == "cuda" and torch.cuda.device_count() > 1: - torch_gpu_device = torch.cuda.device -elif DEVICE_TYPE == "xpu" and torch.xpu.device_count() > 1: - torch_gpu_device = torch.xpu.device +if DEVICE_COUNT > 1: + if DEVICE_TYPE == "cuda": + torch_gpu_device = torch.cuda.device + elif DEVICE_TYPE == "xpu": + torch_gpu_device = torch.xpu.device else: from contextlib import nullcontext def torch_gpu_device(device): return nullcontext() @@ -100,7 +101,7 @@ def torch_gpu_device(device): return nullcontext() # INTEL GPU Specific Logic if DEVICE_TYPE == "xpu": - _gpu_getCurrentRawStream = torch._C._xpu_getCurrentRawStream + _gpu_getCurrentRawStream = torch._C._xpu_getCurrentRawStream # NVIDIA GPU Default Logic else: _gpu_getCurrentRawStream = torch._C._cuda_getCurrentRawStream @@ -121,12 +122,12 @@ def _get_tensor_stream(tensor: torch_Tensor) -> c_void_p: if DEVICE_TYPE == "xpu": _XPU_STREAMS = { (index := torch.xpu.device(i).idx) : ctypes.c_void_p(torch._C._xpu_getCurrentRawStream(index)) - for i in range(torch.xpu.device_count()) + for i in range(DEVICE_COUNT) } XPU_STREAMS = [None] * (max(_XPU_STREAMS.keys()) + 1) WEIGHT_BUFFERS = [None] * (max(_XPU_STREAMS.keys()) + 1) ABSMAX_BUFFERS = [None] * (max(_XPU_STREAMS.keys()) + 1) - for k, v in _XPU_STREAMS.items(): + for k, v in _XPU_STREAMS.items(): XPU_STREAMS[k] = v XPU_STREAMS = tuple(XPU_STREAMS) del _XPU_STREAMS @@ -134,7 +135,7 @@ def _get_tensor_stream(tensor: torch_Tensor) -> c_void_p: # NVIDIA GPU Default Logic _CUDA_STREAMS = { (index := torch.cuda.device(i).idx) : ctypes.c_void_p(torch._C._cuda_getCurrentRawStream(index)) - for i in range(torch.cuda.device_count()) + for i in range(DEVICE_COUNT) } CUDA_STREAMS = [None] * (max(_CUDA_STREAMS.keys()) + 1) 
WEIGHT_BUFFERS = [None] * (max(_CUDA_STREAMS.keys()) + 1) @@ -152,16 +153,16 @@ def _get_tensor_stream(tensor: torch_Tensor) -> c_void_p: # TODO: After adding XPU BNB support, this function should be implemented def cdequantize_blockwise_fp32(*args, **kwargs): raise RuntimeError("XPU BNB support is not implemented yet. cdequantize_blockwise_fp32 should not be called now.") - + def cdequantize_blockwise_fp16_nf4(*args, **kwargs): raise RuntimeError("XPU BNB support is not implemented yet. cdequantize_blockwise_fp16_nf4 should not be called now.") - + def cdequantize_blockwise_bf16_nf4(*args, **kwargs): raise RuntimeError("XPU BNB support is not implemented yet. cdequantize_blockwise_bf16_nf4 should not be called now.") - + def cgemm_4bit_inference_naive_fp16(*args, **kwargs): raise RuntimeError("XPU BNB support is not implemented yet. cgemm_4bit_inference_naive_fp16 should not be called now.") - + def cgemm_4bit_inference_naive_bf16(*args, **kwargs): raise RuntimeError("XPU BNB support is not implemented yet. cgemm_4bit_inference_naive_bf16 should not be called now.") else: @@ -193,7 +194,7 @@ def get_lora_parameters(proj): adapter = getattr(proj, "active_adapters", None) if adapter is None: adapter = getattr(proj, "active_adapter", ("default")) adapter = adapter[0] - + return ( W, getattr(W, "quant_state", None), @@ -232,7 +233,7 @@ def get_lora_parameters_bias(proj): if DEVICE_TYPE == "xpu" and HAS_XPU_STREAM: @torch.inference_mode def fast_dequantize(W, quant_state = None, out = None, use_global_buffer = False): - # TODO: After adding XPU BNB support, check this function + # TODO: After adding XPU BNB support, check this function if quant_state is None: return W if type(quant_state) is not list: # New quant_state as a class @@ -535,7 +536,7 @@ def fast_gemv(X, W, quant_state, out = None): device = W.device device_index = device.index CUDA_STREAM = CUDA_STREAMS[device_index] - + # assert(dtype == X.dtype) bout = shape[0] @@ -669,7 +670,7 @@ def fast_linear_forward(proj, X, temp_lora = None, out = None): lora_A._fast_lora = lora_A.to(dtype) lora_B._fast_lora = lora_B.to(dtype) pass - + if bsz == 1: out = out.view(out_dim) temp_lora = torch_mv(lora_A._fast_lora, X.ravel(), out = temp_lora) @@ -709,6 +710,6 @@ def matmul_lora(X, W, W_quant, A, B, s, out = None): out.addmm_(XA, B.to(dtype), alpha = s) # out += (X @ A.to(dtype)) @ (s * B.to(dtype)) pass - + return out.view(batch, seq_len, -1) if reshape else out pass diff --git a/unsloth/models/_utils.py b/unsloth/models/_utils.py index f576d17dc..c6ff7b626 100644 --- a/unsloth/models/_utils.py +++ b/unsloth/models/_utils.py @@ -77,7 +77,7 @@ import re import warnings, subprocess, re, inspect, psutil, os, math from unsloth_zoo.utils import Version -from unsloth_zoo import DEVICE_TYPE +from unsloth import DEVICE_TYPE, DEVICE_COUNT from unsloth_zoo.tokenizer_utils import ( patch_tokenizer as _patch_tokenizer, @@ -142,12 +142,6 @@ import logging logging.getLogger("transformers.tokenization_utils_base").setLevel(logging.CRITICAL+1) -def get_device_num(): - if DEVICE_TYPE == "xpu": - return torch.xpu.device_count() - else: - return torch.cuda.device_count() - # Ignore logging messages class HideLoggingMessage(logging.Filter): __slots__ = "text", @@ -746,8 +740,7 @@ def get_statistics(): pass pass try: - devices = get_device_num() - _get_statistics(f"{devices if devices <= 8 else 9}") + _get_statistics(f"{DEVICE_COUNT if DEVICE_COUNT <= 8 else 9}") except: pass if disabled: enable_progress_bars() @@ -773,7 +766,7 @@ def get_statistics(): ) 
exec(BitsAndBytesConfig__init__, globals()) -if get_device_num() == 1: +if DEVICE_COUNT == 1: from accelerate.utils.dataclasses import DistributedType def _prepare_backend(self, *args, **kwargs): return None, DistributedType.NO import accelerate.state @@ -781,6 +774,36 @@ def _prepare_backend(self, *args, **kwargs): return None, DistributedType.NO accelerate.accelerator.Accelerator.distributed_type = lambda *args, **kwargs: DistributedType.NO pass +# to move multiple tensors to the same device +def move_to_device(target_device, *tensors): + """ + Move multiple tensors to target device if they're not already there. + + Args: + target_device: The target device to move tensors to + *tensors: Variable number of tensors to potentially move + + Returns: + tuple: The tensors on the target device (same objects if already on device, new if moved) + """ + if isinstance(target_device, int): + target_device = torch.device(target_device) + elif isinstance(target_device, str): + # if string we expect it to be a device name like "cuda:0" + target_device = torch.device(target_device) + elif isinstance(target_device, torch.device): + pass + else: + raise ValueError(f"Invalid target device: {target_device}") + pass + moved_tensors = [] + for tensor in tensors: + if tensor.device != target_device: + moved_tensors.append(tensor.to(target_device)) + else: + moved_tensors.append(tensor) + return tuple(moved_tensors) if len(moved_tensors) > 1 else moved_tensors[0] + import transformers.utils.quantization_config transformers.utils.quantization_config.BitsAndBytesConfig.__init__ = _BitsAndBytesConfig__init__ # ============================================= diff --git a/unsloth/models/cohere.py b/unsloth/models/cohere.py index bfa833be9..d4691fb5d 100644 --- a/unsloth/models/cohere.py +++ b/unsloth/models/cohere.py @@ -78,7 +78,7 @@ def CohereAttention_fast_forward( position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, *args, **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - + # Clear inference if hasattr(self, "paged_attention"): del self.paged_attention_K @@ -254,6 +254,7 @@ def CohereAttention_fast_forward_inference( do_prefill = False, attention_mask = None, ): + Xn = hidden_states bsz, _, hd = hidden_states.size() K1, V1 = past_key_value @@ -281,14 +282,14 @@ def CohereAttention_fast_forward_inference( self.temp_QA = torch.empty((2, bsz, 1, attention_size), dtype = dtype, device = "cuda:0") self.temp_KV = torch.empty((2, bsz, 1, n_kv_heads*head_dim), dtype = dtype, device = "cuda:0") self.RH_Q = torch.empty((bsz, n_heads, 1, head_dim), dtype = dtype, device = "cuda:0") - + # Mistral Nemo 12b has weird dimensions if attention_size != hidden_size: self.temp_O = torch.empty((1, bsz, hidden_size), dtype = dtype, device = "cuda:0") else: self.temp_O = self.temp_QA[1][:,:,:hidden_size] pass - + self.attention = torch.empty((bsz, n_heads, 1, KV_CACHE_INCREMENT+seq_len), dtype = dtype, device = "cuda:0") self.scalar = 1.0 / math_sqrt(self.head_dim) self.half_head_dim = head_dim // 2 @@ -320,7 +321,7 @@ def CohereAttention_fast_forward_inference( # cos, sin = self.rotary_emb(Vn, seq_len = kv_seq_len) # Qn, Kn = inplace_rope_embedding(Qn, Kn, cos, sin, position_ids) - cos, sin = self.rotary_emb.get_cached(kv_seq_len) + cos, sin = self.rotary_emb.get_cached(kv_seq_len, Qn.device.index) cos = cos[position_ids].unsqueeze(1) sin = sin[position_ids].unsqueeze(1) h = self.half_head_dim @@ -338,7 +339,7 @@ def CohereAttention_fast_forward_inference( 
torch.neg(RH_K[:,:,:,:h], out = RH_K[:,:,:,:h]) Kn *= cos Kn.addcmul_(RH_K, sin) - + # New KV cache # Kn = torch.cat([K1, Kn], dim = 2) # Vn = torch.cat([V1, Vn], dim = 2) @@ -397,7 +398,7 @@ def CohereModel_fast_forward_inference( position_ids, attention_mask = None, ): - out_weight = torch.empty_like(self.model.layers[0].input_layernorm.weight, dtype = torch.float32, device = "cuda:0") + out_weights = tuple(torch.empty_like(self.model.layers[0].input_layernorm.weight, dtype = torch.float32, device = torch.device(x)) for x in range(DEVICE_COUNT)) input_ids = input_ids[:,:self.max_seq_length] hidden_states = self.model.embed_tokens(input_ids) hidden_states = hidden_states.to(self.config.torch_dtype) @@ -417,8 +418,12 @@ def CohereModel_fast_forward_inference( next_decoder_cache = [] for idx, decoder_layer in enumerate(self.model.layers): + device_index = getattr(decoder_layer, "_per_layer_device_index", 0) + hidden_states, position_ids = move_to_device( + device_index, hidden_states, position_ids + ) residual = hidden_states - hidden_states = fast_layernorm_inference(decoder_layer.input_layernorm, hidden_states, out_weight) + hidden_states = fast_layernorm_inference(decoder_layer.input_layernorm, hidden_states, out_weights[device_index]) hidden_states_attention, present_key_value = CohereAttention_fast_forward_inference( decoder_layer.self_attn, hidden_states = hidden_states, @@ -435,7 +440,7 @@ def CohereModel_fast_forward_inference( next_decoder_cache.append(present_key_value) pass - hidden_states = fast_layernorm_inference(self.model.norm, hidden_states, out_weight) + hidden_states = fast_layernorm_inference(self.model.norm, hidden_states, out_weights[device_index]) return BaseModelOutputWithPast( last_hidden_state = hidden_states, @@ -468,7 +473,7 @@ def pre_patch(): CohereForCausalLM .forward = CausalLM_fast_forward(CohereModel_fast_forward_inference) PeftModelForCausalLM .forward = PeftModel_fast_forward fix_prepare_inputs_for_generation(CohereForCausalLM) - + import transformers.models.cohere.modeling_cohere transformers.models.cohere.modeling_cohere.CohereRotaryEmbedding = LlamaRotaryEmbedding return diff --git a/unsloth/models/falcon_h1.py b/unsloth/models/falcon_h1.py index 8978e4db0..2cbb78f8a 100644 --- a/unsloth/models/falcon_h1.py +++ b/unsloth/models/falcon_h1.py @@ -69,7 +69,7 @@ def FalconH1Attention_fast_forward( position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, *args, **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - + # Clear inference if hasattr(self, "paged_attention"): del self.paged_attention_K @@ -110,12 +110,13 @@ def FalconH1Attention_fast_forward( # Extend RoPE dynamically to fit in VRA rotary_emb = self.rotary_emb rotary_emb.extend_rope_embedding(V, seq_len = kv_seq_len) + device_index = Q.device.index if position_ids is None: # Useful for LongRoPE - cos, sin = rotary_emb.get_cached(kv_seq_len) + cos, sin = rotary_emb.get_cached(kv_seq_len, device_index) else: - cos, sin = rotary_emb(V, seq_len = kv_seq_len) + cos, sin = rotary_emb.get_cached(kv_seq_len, device_index) Q, K = fast_rope_embedding(Q, K, cos, sin) if past_key_value is not None: @@ -245,14 +246,14 @@ def FalconH1Attention_fast_forward_inference( self.temp_QA = torch.empty((2, bsz, 1, attention_size), dtype = dtype, device = device) self.temp_KV = torch.empty((2, bsz, 1, n_kv_heads*head_dim), dtype = dtype, device = device) self.RH_Q = torch.empty((bsz, n_heads, 1, head_dim), dtype = dtype, device = device) - + # Mistral Nemo 12b has 
weird dimensions if attention_size != hidden_size: self.temp_O = torch.empty((1, bsz, hidden_size), dtype = dtype, device = device) else: self.temp_O = self.temp_QA[1][:,:,:hidden_size] pass - + self.attention = torch.empty((bsz, n_heads, 1, KV_CACHE_INCREMENT+seq_len), dtype = dtype, device = device) self.scalar = 1.0 / math_sqrt(self.head_dim) self.half_head_dim = head_dim // 2 @@ -280,7 +281,7 @@ def FalconH1Attention_fast_forward_inference( # Need to do it prior 2 steps before hitting full on short KV cache # or else error self.rotary_emb.extend_rope_embedding(Vn, seq_len + 2) - cos, sin = self.rotary_emb.get_cached(kv_seq_len) + cos, sin = self.rotary_emb.get_cached(kv_seq_len, Qn.device.index) cos = cos[position_ids].unsqueeze(1) sin = sin[position_ids].unsqueeze(1) h = self.half_head_dim @@ -298,7 +299,7 @@ def FalconH1Attention_fast_forward_inference( RH_K[:,:,:,:h].neg_() #torch.neg(RH_K[:,:,:,:h], out = RH_K[:,:,:,:h]) Kn *= cos Kn.addcmul_(RH_K, sin) - + # New KV cache # Kn = torch.cat([K1, Kn], dim = 2) # Vn = torch.cat([V1, Vn], dim = 2) @@ -580,7 +581,7 @@ def _fast_prepare_inputs_for_generation( **kwargs,): # Overwitten -- has a unique cache type, `FalconHybridMambaAttentionDynamicCache` empty_past_kv = past_key_values is None - + # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens # Exception 1: when passing input_embeds, input_ids may be missing entries # Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here diff --git a/unsloth/models/gemma.py b/unsloth/models/gemma.py index 8ad1c7e62..e43b205ec 100644 --- a/unsloth/models/gemma.py +++ b/unsloth/models/gemma.py @@ -149,7 +149,7 @@ def GemmaModel_fast_forward_inference( position_ids, attention_mask = None, ): - out_weight = torch.empty_like(self.model.layers[0].input_layernorm.weight, dtype = torch.float32, device = "cuda:0") + out_weights = tuple(torch.empty_like(self.model.layers[0].input_layernorm.weight, dtype = torch.float32, device = torch.device(x)) for x in range(DEVICE_COUNT)) input_ids = input_ids[:,:self.max_seq_length] hidden_states = self.model.embed_tokens(input_ids) hidden_states = hidden_states.to(self.config.torch_dtype) @@ -170,8 +170,13 @@ def GemmaModel_fast_forward_inference( next_decoder_cache = [] for idx, decoder_layer in enumerate(self.model.layers): + device_index = getattr(decoder_layer, "_per_layer_device_index", 0) + hidden_states, position_ids = move_to_device( + device_index, hidden_states, position_ids + ) + residual = hidden_states - hidden_states = fast_rms_layernorm_inference_gemma(decoder_layer.input_layernorm, hidden_states, out_weight) + hidden_states = fast_rms_layernorm_inference_gemma(decoder_layer.input_layernorm, hidden_states, out_weights[device_index]) hidden_states, present_key_value = LlamaAttention_fast_forward_inference( decoder_layer.self_attn, hidden_states = hidden_states, @@ -183,13 +188,13 @@ def GemmaModel_fast_forward_inference( hidden_states += residual residual = hidden_states - hidden_states = fast_rms_layernorm_inference_gemma(decoder_layer.post_attention_layernorm, hidden_states, out_weight) + hidden_states = fast_rms_layernorm_inference_gemma(decoder_layer.post_attention_layernorm, hidden_states, out_weights[device_index]) hidden_states = fast_geglu_inference(decoder_layer.mlp, hidden_states) hidden_states += residual next_decoder_cache.append(present_key_value) pass - hidden_states = fast_rms_layernorm_inference_gemma(self.model.norm, hidden_states, out_weight) 
+ hidden_states = fast_rms_layernorm_inference_gemma(self.model.norm, hidden_states, out_weights[device_index]) return BaseModelOutputWithPast( last_hidden_state = hidden_states, @@ -224,9 +229,16 @@ def __init__(self, dim = None, max_position_embeddings=2048, base=10000, device= self.base = base # Dynamic RoPE we first set it to a max of 4 * 8192 tokens then we iteratively grow this self.current_rope_size = min(4 * 8192, self.max_position_embeddings) + self.multi_gpu_cos_cached = [None]*DEVICE_COUNT + self.multi_gpu_sin_cached = [None]*DEVICE_COUNT # Build here to make `torch.jit.trace` work. - self._set_cos_sin_cache(seq_len=self.current_rope_size, device=device, dtype=torch.get_default_dtype()) + for device in range(DEVICE_COUNT): + self._set_cos_sin_cache(seq_len=self.current_rope_size, device=torch.device(device), dtype=torch.get_default_dtype()) + + # dummy so that patch_utils doesn't fail for now + self.cos_cached = torch.empty(1, device=torch.cuda.current_device(), dtype=torch.get_default_dtype()) + self.sin_cached = torch.empty(1, device=torch.cuda.current_device(), dtype=torch.get_default_dtype()) pass def _set_cos_sin_cache(self, seq_len, device, dtype): @@ -245,32 +257,38 @@ def _set_cos_sin_cache(self, seq_len, device, dtype): emb = torch.cat((radians_new, radians_new), dim = -1) # We must do RoPE in float32! - cos = emb.cos().to(device = "cuda", non_blocking = True)#, dtype = dtype) - sin = emb.sin().to(device = "cuda", non_blocking = True)#, dtype = dtype) - self.register_buffer("cos_cached", cos, persistent = False) - self.register_buffer("sin_cached", sin, persistent = False) + cos = emb.cos().to(device = device, non_blocking = True)#, dtype = dtype) + sin = emb.sin().to(device = device, non_blocking = True)#, dtype = dtype) + self.multi_gpu_cos_cached[device.index] = cos + self.multi_gpu_sin_cached[device.index] = sin + return cos, sin pass def forward(self, x, position_ids=None, seq_len=None): # x: [bs, num_attention_heads, seq_len, head_size] - if seq_len > self.current_rope_size: + if seq_len is not None and seq_len > self.current_rope_size: self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype) + device_index = x.device.index + return ( - self.cos_cached[:seq_len].to(dtype=x.dtype), - self.sin_cached[:seq_len].to(dtype=x.dtype), + self.multi_gpu_cos_cached[device_index][:seq_len], + self.multi_gpu_sin_cached[device_index][:seq_len], ) pass - def get_cached(self, seq_len = None): - return self.cos_cached, self.sin_cached + def get_cached(self, seq_len = None, device_index = None): + if device_index is None: + device_index = torch.cuda.current_device() + return self.multi_gpu_cos_cached[device_index], self.multi_gpu_sin_cached[device_index] pass def extend_rope_embedding(self, x, seq_len): if seq_len <= self.current_rope_size: return # Iteratively grow by increments of 8192 self.current_rope_size = math.ceil(seq_len / 8192) * 8192 - self._set_cos_sin_cache(self.current_rope_size, device = "cuda", dtype = x.dtype) + for device in range(DEVICE_COUNT): + self._set_cos_sin_cache(self.current_rope_size, device = torch.device(device), dtype = x.dtype) pass pass @@ -288,7 +306,7 @@ def __init__(self, dim = None, max_position_embeddings=2048, base=10000, device= pass def _set_cos_sin_cache(self, seq_len, device, dtype): -# Note: on the original Llama codebase, these tensors are created on the target device (and not on CPU) and + # Note: on the original Llama codebase, these tensors are created on the target device (and not on CPU) and # in FP32. 
They are applied (multiplied) in FP32 as well. self.current_rope_size = seq_len @@ -304,10 +322,11 @@ def _set_cos_sin_cache(self, seq_len, device, dtype): emb = torch.cat((radians_new, radians_new), dim = -1) # We must do RoPE in float32! - cos = emb.cos().to(device = "cuda", non_blocking = True)#, dtype = dtype) - sin = emb.sin().to(device = "cuda", non_blocking = True)#, dtype = dtype) - self.register_buffer("cos_cached", cos, persistent = False) - self.register_buffer("sin_cached", sin, persistent = False) + cos = emb.cos().to(device = device, non_blocking = True)#, dtype = dtype) + sin = emb.sin().to(device = device, non_blocking = True)#, dtype = dtype) + self.multi_gpu_cos_cached[device.index] = cos + self.multi_gpu_sin_cached[device.index] = sin + return cos, sin pass pass diff --git a/unsloth/models/gemma2.py b/unsloth/models/gemma2.py index 23b91ff6f..5597995b0 100644 --- a/unsloth/models/gemma2.py +++ b/unsloth/models/gemma2.py @@ -84,7 +84,7 @@ def Gemma2Attention_fast_forward( padding_mask: Optional[torch.LongTensor] = None, *args, **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - + # Clear inference if hasattr(self, "paged_attention"): del self.paged_attention_K @@ -113,12 +113,13 @@ def Gemma2Attention_fast_forward( if past_key_value is not None: kv_seq_len += past_key_value[0].shape[-2] + device_index = Q.device.index if position_ids is None: - cos = self.rotary_emb.cos_cached - sin = self.rotary_emb.sin_cached + cos = self.rotary_emb.multi_gpu_cos_cached[device_index] + sin = self.rotary_emb.multi_gpu_sin_cached[device_index] Q, K = fast_rope_embedding(Q, K, cos, sin) else: - cos, sin = self.rotary_emb(V, seq_len = kv_seq_len) + cos, sin = self.rotary_emb.get_cached(kv_seq_len, device_index) Q, K = inplace_rope_embedding(Q, K, cos, sin, position_ids) pass @@ -281,7 +282,7 @@ def Gemma2Attention_fast_forward_inference( # Only for Gemma2 self.temp_O = torch.empty((1, bsz, hidden_size), dtype = dtype, device = device) self.attention = torch.empty((bsz, n_heads, 1, KV_CACHE_INCREMENT+seq_len), dtype = dtype, device = device) - + # See https://github.com/google/gemma_pytorch/commit/03e657582d17cb5a8617ebf333c1c16f3694670e # Gemma 9b should use 256 and not 224 (hs / nah). 
27b uses the below # We default to using the config file itself @@ -307,8 +308,9 @@ def Gemma2Attention_fast_forward_inference( # cos, sin = self.rotary_emb(Vn, seq_len = kv_seq_len) # Qn, Kn = inplace_rope_embedding(Qn, Kn, cos, sin, position_ids) - cos = self.rotary_emb.cos_cached[position_ids].unsqueeze(1) - sin = self.rotary_emb.sin_cached[position_ids].unsqueeze(1) + cos, sin = self.rotary_emb.get_cached(kv_seq_len, Qn.device.index) + cos = cos[position_ids].unsqueeze(1) + sin = sin[position_ids].unsqueeze(1) h = self.half_head_dim RH_Q = self.RH_Q @@ -324,7 +326,7 @@ def Gemma2Attention_fast_forward_inference( torch.neg(RH_K[:,:,:,:h], out = RH_K[:,:,:,:h]) Kn *= cos Kn.addcmul_(RH_K, sin) - + # New KV cache # Kn = torch.cat([K1, Kn], dim = 2) # Vn = torch.cat([V1, Vn], dim = 2) @@ -386,7 +388,7 @@ def Gemma2Model_fast_forward_inference( position_ids, attention_mask = None, ): - out_weight = torch.empty_like(self.model.layers[0].input_layernorm.weight, dtype = torch.float32, device = "cuda:0") + out_weights = tuple(torch.empty_like(self.model.layers[0].input_layernorm.weight, dtype = torch.float32, device = torch.device(x)) for x in range(DEVICE_COUNT)) input_ids = input_ids[:,:self.max_seq_length] hidden_states = self.model.embed_tokens(input_ids) hidden_states = hidden_states.to(self.config.torch_dtype) @@ -422,10 +424,17 @@ def Gemma2Model_fast_forward_inference( next_decoder_cache = [] for idx, decoder_layer in enumerate(self.model.layers): + # For pipeline parallelism, we need to move all tensors to the same device + # note that this movement is once per GPU in PP + device_index = getattr(decoder_layer, "_per_layer_device_index", 0) + hidden_states, position_ids = move_to_device( + device_index, hidden_states, position_ids + ) + use_sliding_window = idx % 2 == 0 residual = hidden_states - hidden_states = fast_rms_layernorm_inference_gemma(decoder_layer.input_layernorm, hidden_states, out_weight) + hidden_states = fast_rms_layernorm_inference_gemma(decoder_layer.input_layernorm, hidden_states, out_weights[device_index]) hidden_states, present_key_value = Gemma2Attention_fast_forward_inference( decoder_layer.self_attn, hidden_states = hidden_states, @@ -435,18 +444,18 @@ def Gemma2Model_fast_forward_inference( do_prefill = not hasattr(decoder_layer.self_attn, "paged_attention"), use_sliding_window = use_sliding_window, ) - hidden_states = fast_rms_layernorm_inference_gemma(decoder_layer.post_attention_layernorm, hidden_states, out_weight) + hidden_states = fast_rms_layernorm_inference_gemma(decoder_layer.post_attention_layernorm, hidden_states, out_weights[device_index]) hidden_states += residual residual = hidden_states - hidden_states = fast_rms_layernorm_inference_gemma(decoder_layer. pre_feedforward_layernorm, hidden_states, out_weight) + hidden_states = fast_rms_layernorm_inference_gemma(decoder_layer. 
pre_feedforward_layernorm, hidden_states, out_weights[device_index]) hidden_states = fast_geglu_inference(decoder_layer.mlp, hidden_states) - hidden_states = fast_rms_layernorm_inference_gemma(decoder_layer.post_feedforward_layernorm, hidden_states, out_weight) + hidden_states = fast_rms_layernorm_inference_gemma(decoder_layer.post_feedforward_layernorm, hidden_states, out_weights[device_index]) hidden_states += residual next_decoder_cache.append(present_key_value) pass - hidden_states = fast_rms_layernorm_inference_gemma(self.model.norm, hidden_states, out_weight) + hidden_states = fast_rms_layernorm_inference_gemma(self.model.norm, hidden_states, out_weights[device_index]) return BaseModelOutputWithPast( last_hidden_state = hidden_states, @@ -479,7 +488,7 @@ def pre_patch(): Gemma2ForCausalLM .forward = CausalLM_fast_forward(Gemma2Model_fast_forward_inference) PeftModelForCausalLM .forward = PeftModel_fast_forward fix_prepare_inputs_for_generation(Gemma2ForCausalLM) - + # Solves https://github.com/unslothai/unsloth/issues/168 # Static KV Cache was introduced in 4.38.0, causing training to be much slower. # Inferene can now be CUDAGraphed, but we shall retain the old rotary embeddings. diff --git a/unsloth/models/granite.py b/unsloth/models/granite.py index 243922fc1..a3d79c833 100644 --- a/unsloth/models/granite.py +++ b/unsloth/models/granite.py @@ -71,7 +71,7 @@ def GraniteAttention_fast_forward( position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, *args, **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - + # Clear inference if hasattr(self, "paged_attention"): del self.paged_attention_K @@ -162,7 +162,7 @@ def GraniteAttention_fast_forward( # Go back to (batch_size, seq_len, n_heads, head_dim) A = A.transpose(1, 2).contiguous() pass - + attn_output = A.reshape(bsz, q_len, n_heads*head_dim) attn_output = self.apply_o(self, attn_output) attn_weights = None @@ -256,7 +256,7 @@ def GraniteAttention_fast_forward_inference( use_sliding_window = False, position_embeddings : Optional[Tuple[torch.Tensor, torch.Tensor]] = None, ): - + assert position_embeddings is not None, f"Granite model requires position embeddings to be specified" Xn = hidden_states @@ -326,7 +326,7 @@ def GraniteAttention_fast_forward_inference( torch.neg(RH_K[:,:,:,:h], out = RH_K[:,:,:,:h]) Kn *= cos Kn.addcmul_(RH_K, sin) - + # New KV cache # Kn = torch.cat([K1, Kn], dim = 2) # Vn = torch.cat([V1, Vn], dim = 2) @@ -349,7 +349,7 @@ def GraniteAttention_fast_forward_inference( Qn *= self.scaling A = torch_matmul(Qn, Kn.transpose(2, 3), out = self.attention[:,:,:,:cached_len]) - + # if attention_mask is not None: A += attention_mask # Must add attention_mask for batched A[:] = torch_nn_functional_softmax(A, dim = -1, dtype = torch.float32)#.to(A.dtype) @@ -395,10 +395,14 @@ def GraniteModel_fast_forward_inference( attention_mask = None pass - position_embeddings = self.model.rotary_emb(hidden_states, position_ids, self.max_seq_length) + position_embeddings = self.model.rotary_emb.get_cached(self.max_seq_length, hidden_states.device.index) next_decoder_cache = [] for idx, decoder_layer in enumerate(self.model.layers): + device_index = getattr(decoder_layer, "_per_layer_device_index", 0) + hidden_states, position_ids = move_to_device( + device_index, hidden_states, position_ids + ) residual = hidden_states hidden_states = fast_rms_layernorm_inference(decoder_layer.input_layernorm, hidden_states) @@ -532,7 +536,7 @@ def post_patch(model, 
tokenizer): elif hasattr(module, "short_cos_cached") and \ (module.short_cos_cached.dtype != correct_dtype): - + module.short_cos_cached = module.short_cos_cached.to(correct_dtype) module.short_sin_cached = module.short_sin_cached.to(correct_dtype) pass @@ -547,4 +551,3 @@ def post_patch(model, tokenizer): return model, tokenizer pass pass - diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index 1253e641e..db0e8843c 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -20,11 +20,12 @@ from ._utils import * from ._utils import patch_unsloth_smart_gradient_checkpointing from ._utils import __version__ +from ._utils import move_to_device from torch.nn.functional import scaled_dot_product_attention from transformers import __version__ as transformers_version from unsloth_zoo.utils import Version, _get_dtype from unsloth_zoo.peft_utils import SKIP_QUANTIZATION_MODULES -from unsloth import DEVICE_TYPE +from unsloth import DEVICE_TYPE, DEVICE_COUNT transformers_version = Version(transformers_version) # Transformers moved rotary embeddings out of all attention layers @@ -121,12 +122,12 @@ def _fast_prepare_inputs_for_generation(self, input_ids, attention_mask=None, ** else: bs, cache_length = input_ids.shape input_ids = input_ids[:,[-1]] - + # Get to the base model base_model = self if hasattr(base_model, 'base_model_prefix'): base_model = getattr(base_model, base_model.base_model_prefix) - + if hasattr(base_model, "_prepare_4d_causal_attention_mask_with_cache_position"): def needs_device_kw(fn) -> bool: try: @@ -243,14 +244,14 @@ def LlamaAttention_fast_forward_inference( self.temp_QA = torch.empty((2, bsz, 1, attention_size), dtype = dtype, device = device) self.temp_KV = torch.empty((2, bsz, 1, n_kv_heads*head_dim), dtype = dtype, device = device) self.RH_Q = torch.empty((bsz, n_heads, 1, head_dim), dtype = dtype, device = device) - + # Mistral Nemo 12b has weird dimensions if attention_size != hidden_size: self.temp_O = torch.empty((1, bsz, hidden_size), dtype = dtype, device = device) else: self.temp_O = self.temp_QA[1][:,:,:hidden_size] pass - + self.attention = torch.empty((bsz, n_heads, 1, KV_CACHE_INCREMENT+seq_len), dtype = dtype, device = device) self.scalar = 1.0 / math_sqrt(self.head_dim) self.half_head_dim = head_dim // 2 @@ -274,7 +275,7 @@ def LlamaAttention_fast_forward_inference( # Need to do it prior 2 steps before hitting full on short KV cache # or else error self.rotary_emb.extend_rope_embedding(Vn, seq_len + 2) - cos, sin = self.rotary_emb.get_cached(kv_seq_len) + cos, sin = self.rotary_emb.get_cached(kv_seq_len, Qn.device.index) cos = cos[position_ids].unsqueeze(1) sin = sin[position_ids].unsqueeze(1) h = self.half_head_dim @@ -292,7 +293,7 @@ def LlamaAttention_fast_forward_inference( RH_K[:,:,:,:h].neg_() #torch.neg(RH_K[:,:,:,:h], out = RH_K[:,:,:,:h]) Kn *= cos Kn.addcmul_(RH_K, sin) - + # New KV cache # Kn = torch.cat([K1, Kn], dim = 2) # Vn = torch.cat([V1, Vn], dim = 2) @@ -361,10 +362,10 @@ def fast_swiglu_inference(self, X, temp_gate = None, temp_up = None, gate_multip # temp = torch.empty((2, bsz, 1, mlp_size), dtype = X.dtype, device = "cuda:0") gate = fast_linear_forward(self.gate_proj, X, out = temp_gate) - + if gate_multiplier is not None: gate *= gate_multiplier - + up = fast_linear_forward(self. 
up_proj, X, out = temp_up) gate = torch_nn_functional_silu(gate, inplace = True) @@ -447,7 +448,7 @@ def LlamaAttention_fast_forward( position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, *args, **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - + # Clear inference if hasattr(self, "paged_attention"): del self.paged_attention_K @@ -483,11 +484,12 @@ def LlamaAttention_fast_forward( rotary_emb = self.rotary_emb rotary_emb.extend_rope_embedding(V, seq_len = kv_seq_len) - if position_ids is None: - # Useful for LongRoPE - cos, sin = rotary_emb.get_cached(kv_seq_len) - else: - cos, sin = rotary_emb(V, seq_len = kv_seq_len) + # if position_ids is None: + # # Useful for LongRoPE + # cos, sin = rotary_emb.get_cached(kv_seq_len, device = Q.device) + # else: + # cos, sin = rotary_emb.get_cached(seq_len = kv_seq_len, device = Q.device) + cos, sin = rotary_emb.get_cached(kv_seq_len, Q.device.index) # Q, K = ( # fast_rope_embedding(Q, K, cos, sin) @@ -672,7 +674,7 @@ def LlamaModel_fast_forward( return_dict: Optional[bool] = None, *args, **kwargs, ) -> Union[Tuple, BaseModelOutputWithPast]: - + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions assert(output_attentions is False) output_hidden_states = ( @@ -707,7 +709,7 @@ def LlamaModel_fast_forward( inputs_embeds = inputs_embeds[:,:self.max_seq_length,:] pass pass - + past_key_values_length = 0 if past_key_values is not None: @@ -794,7 +796,7 @@ def LlamaModel_fast_forward( # Ignore attention_mask if attention_mask is None: padding_mask = None - elif self.training and os.environ.get("UNSLOTH_KEEP_PADDING", "0") != '1': + elif self.training and os.environ.get("UNSLOTH_KEEP_PADDING", "0") != '1': attention_mask = None padding_mask = None else: @@ -911,7 +913,7 @@ def LlamaModel_fast_forward( # Also, transformers 4.45.0 supports granite but with the attention refactor (it always had the refactor) # unsloth's check for granite too has "version >= 4.45.0 (rightly so)". # so let granite always use the attention refactor implementation. 
- position_embeddings = self.rotary_emb(hidden_states, position_ids, self.config.max_position_embeddings) + position_embeddings = self.rotary_emb.get_cached(self.config.max_position_embeddings, hidden_states.device.index) else: position_embeddings = None @@ -1021,7 +1023,7 @@ def LlamaModel_fast_forward_inference_custom( XX, XX2 = _XX[0], _XX[1] variance = torch.empty((bsz, q_len, 1), dtype = torch.float32, device = f"{DEVICE_TYPE}:0") temp_mlp = torch.empty((2, bsz, 1, mlp_size), dtype = X.dtype, device = f"{DEVICE_TYPE}:0") - temp_gate, temp_up = temp_mlp[0], temp_mlp[1] + temp_gates, temp_ups = tuple(temp_mlp[0].to(torch.device(x)) for x in range(DEVICE_COUNT)), tuple(temp_mlp[1].to(torch.device(x)) for x in range(DEVICE_COUNT)) seq_len = past_key_values[0][0].shape[-2] if bsz != 1: @@ -1039,6 +1041,10 @@ def LlamaModel_fast_forward_inference_custom( next_decoder_cache = [] for idx, decoder_layer in enumerate(self.model.layers): + device_index = getattr(decoder_layer, "_per_layer_device_index", 0) + X, residual, position_ids = move_to_device( + device_index, X, residual, position_ids + ) residual.copy_(X) # residual = X X = fast_rms_layernorm_inference( decoder_layer.input_layernorm, @@ -1068,8 +1074,8 @@ def LlamaModel_fast_forward_inference_custom( X = mlp_fast_forward_inference( decoder_layer.mlp, X, - temp_gate = temp_gate, - temp_up = temp_up, + temp_gate = temp_gates[device_index], + temp_up = temp_ups[device_index], ) X += residual @@ -1154,7 +1160,7 @@ def _CausalLM_fast_forward( logit_scaling = getattr(self.config, "logit_scale", 0) dtype = lm_head.dtype num_logits_to_keep = max(num_logits_to_keep, logits_to_keep) - + # Move items to same device as lm_head hidden_states = hidden_states.to(lm_head_device) if labels is not None: labels = labels.to(lm_head_device) @@ -1181,7 +1187,7 @@ def _CausalLM_fast_forward( RETURN_LOGITS = os.environ.get("UNSLOTH_RETURN_LOGITS", "0") == "1" # < 1024 Normal Unsloth uses less VRAM! if bsz*q_len <= 1024: RETURN_LOGITS = True - + if not RETURN_LOGITS and HAS_CUT_CROSS_ENTROPY and labels is not None: n_items = kwargs.get("num_items_in_batch", None) or kwargs.get("n_items", None) @@ -1295,15 +1301,15 @@ def PeftModel_fast_forward( **kwargs, ): is_classification = "Classification" in str(type(self.base_model.model)) - if is_classification: + if is_classification: return self.base_model( input_ids = input_ids, - attention_mask = attention_mask, - inputs_embeds = inputs_embeds, - labels = labels, + attention_mask = attention_mask, + inputs_embeds = inputs_embeds, + labels = labels, output_attentions = output_attentions, - output_hidden_states = output_hidden_states, - return_dict = return_dict, + output_hidden_states = output_hidden_states, + return_dict = return_dict, **kwargs, ) else: @@ -1351,9 +1357,16 @@ def __init__(self, dim = None, max_position_embeddings=2048, base=10000, device= self.base = base # Dynamic RoPE we first set it to a max of 4 * 8192 tokens then we iteratively grow this self.current_rope_size = min(4 * 8192, self.max_position_embeddings) + self.multi_gpu_cos_cached = [None]*DEVICE_COUNT + self.multi_gpu_sin_cached = [None]*DEVICE_COUNT # Build here to make `torch.jit.trace` work. 
- self._set_cos_sin_cache(seq_len=self.current_rope_size, device=device, dtype=torch.get_default_dtype()) + for device_idx in range(DEVICE_COUNT): + self._set_cos_sin_cache(seq_len=self.current_rope_size, device=torch.device(device_idx), dtype=torch.get_default_dtype()) + + # dummy so that patch_utils doesn't fail for now + self.cos_cached = torch.empty(1, device=torch.cuda.current_device(), dtype=torch.get_default_dtype()) + self.sin_cached = torch.empty(1, device=torch.cuda.current_device(), dtype=torch.get_default_dtype()) pass def _set_cos_sin_cache(self, seq_len, device, dtype): @@ -1368,30 +1381,37 @@ def _set_cos_sin_cache(self, seq_len, device, dtype): freqs = torch.outer(t, inv_freq) # Different from paper, but it uses a different permutation in order to obtain the same calculation emb = torch.cat((freqs, freqs), dim=-1) - self.register_buffer("cos_cached", emb.cos().to(dtype=dtype, device=device, non_blocking=True), persistent=False) - self.register_buffer("sin_cached", emb.sin().to(dtype=dtype, device=device, non_blocking=True), persistent=False) + cos = emb.cos().to(dtype=dtype, device=device, non_blocking=True) + sin = emb.sin().to(dtype=dtype, device=device, non_blocking=True) + self.multi_gpu_cos_cached[device.index] = cos + self.multi_gpu_sin_cached[device.index] = sin + return cos, sin pass def forward(self, x, position_ids=None, seq_len=None): # x: [bs, num_attention_heads, seq_len, head_size] - if seq_len > self.current_rope_size: + if seq_len is not None and seq_len > self.current_rope_size: self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype) + device_index = x.device.index return ( - self.cos_cached[:seq_len].to(dtype = x.dtype), - self.sin_cached[:seq_len].to(dtype = x.dtype), + self.multi_gpu_cos_cached[device_index][:seq_len], + self.multi_gpu_sin_cached[device_index][:seq_len], ) pass - def get_cached(self, seq_len = None): - return self.cos_cached, self.sin_cached + def get_cached(self, seq_len = None, device_index = None): + if device_index is None: + device_index = torch.cuda.current_device() + return self.multi_gpu_cos_cached[device_index], self.multi_gpu_sin_cached[device_index] pass def extend_rope_embedding(self, x, seq_len): if seq_len <= self.current_rope_size: return # Iteratively grow by increments of 8192 self.current_rope_size = ((seq_len // 8192) + ((seq_len % 8192) != 0)) * 8192 - self._set_cos_sin_cache(self.current_rope_size, device = DEVICE_TYPE, dtype = x.dtype) + for device_idx in range(DEVICE_COUNT): + self._set_cos_sin_cache(self.current_rope_size, device = torch.device(device_idx), dtype = x.dtype) pass pass @@ -1419,8 +1439,11 @@ def _set_cos_sin_cache(self, seq_len, device, dtype): freqs = torch.outer(t, inv_freq) # Different from paper, but it uses a different permutation in order to obtain the same calculation emb = torch.cat((freqs, freqs), dim=-1) - self.register_buffer("cos_cached", emb.cos().to(dtype=dtype, device=device, non_blocking=True), persistent=False) - self.register_buffer("sin_cached", emb.sin().to(dtype=dtype, device=device, non_blocking=True), persistent=False) + cos = emb.cos().to(dtype=dtype, device=device, non_blocking=True) + sin = emb.sin().to(dtype=dtype, device=device, non_blocking=True) + self.multi_gpu_cos_cached[device.index] = cos + self.multi_gpu_sin_cached[device.index] = sin + return cos, sin pass pass @@ -1446,6 +1469,8 @@ def __init__(self, dim = None, max_position_embeddings=2048, base=10000, device= self.base = base # Dynamic RoPE we first set it to a max of 4 * 8192 tokens then we 
iteratively grow this self.current_rope_size = min(4 * 8192, self.max_position_embeddings) + self.multi_gpu_cos_cached = [None]*DEVICE_COUNT + self.multi_gpu_sin_cached = [None]*DEVICE_COUNT # Normal Llama-3 RoPE inv_freq = 1.0 / ( @@ -1455,21 +1480,54 @@ def __init__(self, dim = None, max_position_embeddings=2048, base=10000, device= self.register_buffer("inv_freq", inv_freq, persistent = False) # Build here to make `torch.jit.trace` work. - self._set_cos_sin_cache(seq_len=self.current_rope_size, device=device, dtype=torch.get_default_dtype()) + for device_idx in range(DEVICE_COUNT): + self._set_cos_sin_cache(seq_len=self.current_rope_size, device=torch.device(device_idx), dtype=torch.get_default_dtype()) + + # dummy so that patch_utils doesn't fail for now + self.cos_cached = torch.empty(1, device=torch.cuda.current_device(), dtype=torch.get_default_dtype()) + self.sin_cached = torch.empty(1, device=torch.cuda.current_device(), dtype=torch.get_default_dtype()) pass def _set_cos_sin_cache(self, seq_len, device, dtype): # Note: on the original Llama codebase, these tensors are created on the target device (and not on CPU) and # in FP32. They are applied (multiplied) in FP32 as well. self.current_rope_size = seq_len - + t = torch.arange(self.current_rope_size, device=self.inv_freq.device, dtype=torch.int64).float() freqs = torch.outer(t, self.inv_freq) # Different from paper, but it uses a different permutation in order to obtain the same calculation emb = torch.cat((freqs, freqs), dim=-1) - self.register_buffer("cos_cached", emb.cos().to(dtype=dtype, device=device, non_blocking=True), persistent=False) - self.register_buffer("sin_cached", emb.sin().to(dtype=dtype, device=device, non_blocking=True), persistent=False) + cos = emb.cos().to(dtype=dtype, device=device, non_blocking=True) + sin = emb.sin().to(dtype=dtype, device=device, non_blocking=True) + self.multi_gpu_cos_cached[device.index] = cos + self.multi_gpu_sin_cached[device.index] = sin + return cos, sin + pass + + def forward(self, x, position_ids=None, seq_len=None): + # x: [bs, num_attention_heads, seq_len, head_size] + if seq_len is not None and seq_len > self.current_rope_size: + self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype) + device_index = x.device.index + return ( + self.multi_gpu_cos_cached[device_index][:seq_len], + self.multi_gpu_sin_cached[device_index][:seq_len], + ) + pass + + def get_cached(self, seq_len = None, device_index = None): + if device_index is None: + device_index = torch.cuda.current_device() + return self.multi_gpu_cos_cached[device_index], self.multi_gpu_sin_cached[device_index] + pass + + def extend_rope_embedding(self, x, seq_len): + if seq_len <= self.current_rope_size: return + # Iteratively grow by increments of 8192 + self.current_rope_size = ((seq_len // 8192) + ((seq_len % 8192) != 0)) * 8192 + for device_idx in range(DEVICE_COUNT): + self._set_cos_sin_cache(self.current_rope_size, device = torch.device(device_idx), dtype = x.dtype) pass # From https://github.com/meta-llama/llama-models/blob/main/models/llama3_1/api/model.py#L41 @@ -1497,28 +1555,6 @@ def apply_scaling(self, freqs: torch.Tensor): new_freqs.append((1 - smooth) * freq / scale_factor + smooth * freq) return torch.tensor(new_freqs, dtype=freqs.dtype, device=freqs.device) pass - - def forward(self, x, position_ids=None, seq_len=None): - # x: [bs, num_attention_heads, seq_len, head_size] - if seq_len > self.current_rope_size: - self._set_cos_sin_cache(seq_len=seq_len, device=x.device, 
dtype=x.dtype) - - return ( - self.cos_cached[:seq_len].to(dtype = x.dtype), - self.sin_cached[:seq_len].to(dtype = x.dtype), - ) - pass - - def get_cached(self, seq_len = None): - return self.cos_cached, self.sin_cached - pass - - def extend_rope_embedding(self, x, seq_len): - if seq_len <= self.current_rope_size: return - # Iteratively grow by increments of 8192 - self.current_rope_size = ((seq_len // 8192) + ((seq_len % 8192) != 0)) * 8192 - self._set_cos_sin_cache(self.current_rope_size, device = DEVICE_TYPE, dtype = x.dtype) - pass pass @@ -1554,6 +1590,10 @@ def __init__(self, self.base = base # Dynamic RoPE we first set it to a max of 4 * 8192 tokens then we iteratively grow this self.current_rope_size = min(original_max_position_embeddings, self.max_position_embeddings) + self.multi_gpu_short_cos_cached = [None]*DEVICE_COUNT + self.multi_gpu_short_sin_cached = [None]*DEVICE_COUNT + self.multi_gpu_long_cos_cached = [None]*DEVICE_COUNT + self.multi_gpu_long_sin_cached = [None]*DEVICE_COUNT # Long RoPE similar to RoPE except short sequences have 1 cos / sin # and long sequences have another cos / sin @@ -1575,64 +1615,78 @@ def __init__(self, # Short and long inv_freq self.register_buffer("short_inv_freq", short_inv_freq, persistent = False) self.register_buffer("long_inv_freq", long_inv_freq, persistent = False) - # Build here to make `torch.jit.trace` work. - # self._set_cos_sin_cache(seq_len=self.current_rope_size, device=device, dtype=torch.get_default_dtype()) - # Short sequences + # Build here to make `torch.jit.trace` work. + # Initialize short sequences cache for all devices dtype = torch.bfloat16 if is_bfloat16_supported() else torch.float16 t = torch.arange(original_max_position_embeddings, device=self.short_inv_freq.device, dtype=torch.int64).float() freqs = torch.outer(t, self.short_inv_freq) emb = torch.cat((freqs, freqs), dim=-1) - cos_cached = (emb.cos() * self.scaling_factor).to(dtype=dtype, device=device, non_blocking=True) - sin_cached = (emb.sin() * self.scaling_factor).to(dtype=dtype, device=device, non_blocking=True) - self.register_buffer("short_cos_cached", cos_cached, persistent=False) - self.register_buffer("short_sin_cached", sin_cached, persistent=False) + + for device_idx in range(DEVICE_COUNT): + device_obj = torch.device(device_idx) + cos_cached = (emb.cos() * self.scaling_factor).to(dtype=dtype, device=device_obj, non_blocking=True) + sin_cached = (emb.sin() * self.scaling_factor).to(dtype=dtype, device=device_obj, non_blocking=True) + self.multi_gpu_short_cos_cached[device_idx] = cos_cached + self.multi_gpu_short_sin_cached[device_idx] = sin_cached + + # dummy so that patch_utils doesn't fail for now + self.short_cos_cached = torch.empty(1, device=torch.cuda.current_device(), dtype=torch.get_default_dtype()) + self.short_sin_cached = torch.empty(1, device=torch.cuda.current_device(), dtype=torch.get_default_dtype()) + self.long_cos_cached = torch.empty(1, device=torch.cuda.current_device(), dtype=torch.get_default_dtype()) + self.long_sin_cached = torch.empty(1, device=torch.cuda.current_device(), dtype=torch.get_default_dtype()) pass def _set_cos_sin_cache(self, seq_len, device, dtype): # Note: on the original Llama codebase, these tensors are created on the target device (and not on CPU) and # in FP32. They are applied (multiplied) in FP32 as well. 
self.current_rope_size = seq_len - + t = torch.arange(self.current_rope_size, device=self.long_inv_freq.device, dtype=torch.int64).float() # Long sequences freqs = torch.outer(t, self.long_inv_freq) emb = torch.cat((freqs, freqs), dim=-1) cos_cached = (emb.cos() * self.scaling_factor).to(dtype=dtype, device=device, non_blocking=True) sin_cached = (emb.sin() * self.scaling_factor).to(dtype=dtype, device=device, non_blocking=True) - self.register_buffer("long_cos_cached", cos_cached, persistent=False) - self.register_buffer("long_sin_cached", sin_cached, persistent=False) + self.multi_gpu_long_cos_cached[device.index] = cos_cached + self.multi_gpu_long_sin_cached[device.index] = sin_cached + return cos_cached, sin_cached pass def forward(self, x, position_ids=None, seq_len=None): # x: [bs, num_attention_heads, seq_len, head_size] - if seq_len > self.current_rope_size: + if seq_len is not None and seq_len > self.current_rope_size: self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype) - if seq_len < self.original_max_position_embeddings: + device_index = x.device.index + + if seq_len is not None and seq_len < self.original_max_position_embeddings: return ( - self.short_cos_cached[:seq_len].to(dtype = x.dtype), - self.short_sin_cached[:seq_len].to(dtype = x.dtype), + self.multi_gpu_short_cos_cached[device_index][:seq_len], + self.multi_gpu_short_sin_cached[device_index][:seq_len], ) else: return ( - self.long_cos_cached[:seq_len].to(dtype = x.dtype), - self.long_sin_cached[:seq_len].to(dtype = x.dtype), + self.multi_gpu_long_cos_cached[device_index][:seq_len], + self.multi_gpu_long_sin_cached[device_index][:seq_len], ) pass pass - def get_cached(self, seq_len = None): - if seq_len < self.original_max_position_embeddings: - return self.short_cos_cached, self.short_sin_cached - return self.long_cos_cached, self.long_sin_cached + def get_cached(self, seq_len = None, device_index = None): + if device_index is None: + device_index = torch.cuda.current_device() + if seq_len is not None and seq_len < self.original_max_position_embeddings: + return self.multi_gpu_short_cos_cached[device_index], self.multi_gpu_short_sin_cached[device_index] + return self.multi_gpu_long_cos_cached[device_index], self.multi_gpu_long_sin_cached[device_index] pass def extend_rope_embedding(self, x, seq_len): if seq_len <= self.current_rope_size: return # Iteratively grow by increments of 8192 self.current_rope_size = ((seq_len // 8192) + ((seq_len % 8192) != 0)) * 8192 - self._set_cos_sin_cache(self.current_rope_size, device = DEVICE_TYPE, dtype = x.dtype) + for device_idx in range(DEVICE_COUNT): + self._set_cos_sin_cache(self.current_rope_size, device = torch.device(device_idx), dtype = x.dtype) pass pass @@ -1754,7 +1808,7 @@ def from_pretrained( max_lora_rank = 16, disable_log_stats = False, unsloth_vllm_standby = False, - num_labels = None, + num_labels = None, **kwargs, ): os.environ["UNSLOTH_USE_NEW_MODEL"] = "0" @@ -1787,7 +1841,6 @@ def from_pretrained( gpu_stats = torch.cuda.get_device_properties(0) gpu_version = torch.version.cuda gpu_stats_snippet = f"CUDA: {gpu_stats.major}.{gpu_stats.minor}. CUDA Toolkit: {gpu_version}." - num_gpus = torch.cuda.device_count() from importlib.metadata import version as importlib_version try: vllm_version = f" vLLM: {importlib_version('vllm')}." 
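[Editor note, not part of the diff] The rotary-embedding changes above replace the single cos_cached/sin_cached buffers with per-device lists indexed by device.index, and get_cached(seq_len, device_index) now hands back the copy that already lives on the caller's GPU. A minimal standalone sketch of that pattern (assuming at least one visible CUDA device; class and parameter names here are illustrative, not the library's):

import torch

DEVICE_COUNT = torch.cuda.device_count()  # assumes >= 1 visible CUDA device

class PerDeviceRoPECache:
    """Keeps one cos/sin cache per GPU so each pipeline stage reads locally."""
    def __init__(self, dim=128, base=10000.0, seq_len=4096):
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        t = torch.arange(seq_len).float()
        freqs = torch.outer(t, inv_freq)
        emb = torch.cat((freqs, freqs), dim=-1)
        self.current_rope_size = seq_len
        # One cos/sin copy per device, built eagerly at init.
        self.cos = [emb.cos().to(torch.device("cuda", i)) for i in range(DEVICE_COUNT)]
        self.sin = [emb.sin().to(torch.device("cuda", i)) for i in range(DEVICE_COUNT)]

    def get_cached(self, seq_len=None, device_index=None):
        # Default to the current device, mirroring the behaviour in the diff.
        if device_index is None:
            device_index = torch.cuda.current_device()
        return self.cos[device_index], self.sin[device_index]

Each attention layer can then call get_cached(kv_seq_len, Q.device.index), so no cross-device copy of the cos/sin tensors happens on the forward pass.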
@@ -1795,7 +1848,6 @@ def from_pretrained( elif DEVICE_TYPE == "xpu": gpu_stats = torch.xpu.get_device_properties(0) gpu_version = torch.version.xpu - num_gpus = torch.xpu.device_count() gpu_stats_snippet = f"Intel Toolkit: {gpu_version}." try: vllm_version = f" vLLM: {importlib_version('vllm')}." @@ -1807,7 +1859,7 @@ def from_pretrained( statistics = \ f"==((====))== Unsloth {__version__}: Fast {model_patcher.__name__[4:-5]} patching. Transformers: {transformers_version}.{vllm_version}\n"\ - f" {chr(92)}{chr(92)} /| {gpu_stats.name}. Num GPUs = {num_gpus}. Max memory: {max_memory} GB. Platform: {platform_system}.\n"\ + f" {chr(92)}{chr(92)} /| {gpu_stats.name}. Num GPUs = {DEVICE_COUNT}. Max memory: {max_memory} GB. Platform: {platform_system}.\n"\ f"O^O/ {chr(92)}_/ {chr(92)} Torch: {torch.__version__}. {gpu_stats_snippet} Triton: {triton_version}\n"\ f"{chr(92)} / Bfloat16 = {str(SUPPORTS_BFLOAT16).upper()}. FA [Xformers = {xformers_version}. FA2 = {HAS_FLASH_ATTENTION}]\n"\ f' "-____-" Free license: http://github.com/unslothai/unsloth' @@ -1827,7 +1879,7 @@ def from_pretrained( if old_hf_transfer != "0": os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1" model_patcher.pre_patch() - get_statistics() # For debugging - we use a download counter to see if environments are not breaking + get_statistics() # For debugging - we use a download counter to see if environments are not breaking if dtype is None: dtype = torch.float16 if not SUPPORTS_BFLOAT16 else torch.bfloat16 @@ -1914,7 +1966,7 @@ def from_pretrained( # Cannot be None, since HF now checks for the config if load_in_4bit: kwargs["quantization_config"] = bnb_config - + if num_labels is not None: model = AutoModelForSequenceClassification.from_pretrained( model_name, @@ -2012,7 +2064,7 @@ def from_pretrained( except: raise RuntimeError('Unsloth: Unsuccessfully patched inner_training_loop') pass - + import transformers.trainer items_in_trainer = dir(transformers.trainer) good_items = [] @@ -2362,7 +2414,7 @@ def get_peft_model( ) loftq_config = LoftQConfig(loftq_bits = 4, loftq_iter = 1) pass - + if hasattr(model.config, "quantization_config"): raise ValueError( "Unsloth: You are using `loftq` init, yet `load_in_4bit = True` was set.\n"\ @@ -2492,7 +2544,7 @@ def get_peft_model( is_classification = "Classification" in str(type(model)) # Get LoRA - # + # arguments = dict( r = r, @@ -2518,7 +2570,7 @@ def get_peft_model( input_embeddings_device = model.get_input_embeddings().weight.device if is_classification: output_embeddings_device = model.score.weight.device - else: + else: output_embeddings_device = model.get_output_embeddings().weight.device if use_gradient_checkpointing == "unsloth": @@ -2678,7 +2730,7 @@ def patch_peft_model( # model.peft_config[active_adapter].revision = f"unsloth" pass - from transformers.trainer import Trainer + from transformers.trainer import Trainer if Trainer._inner_training_loop.__name__ != "_fast_inner_training_loop": raise RuntimeError("Unsloth: Unsuccessfully patched Trainer! 
Please file a bug report!") pass @@ -2812,7 +2864,7 @@ def patch_peft_model( internal_model.max_seq_length = max_seq_length internal_model = internal_model.model pass - internal_model.max_seq_length = max_seq_length + internal_model.max_seq_length = max_seq_length # Patch tokenizer to pad to the right internal_model = model diff --git a/unsloth/models/mistral.py b/unsloth/models/mistral.py index a3e07e3b0..68d4ba43f 100644 --- a/unsloth/models/mistral.py +++ b/unsloth/models/mistral.py @@ -51,7 +51,7 @@ def MistralAttention_fast_forward( position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, *args, **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - + # Clear inference if hasattr(self, "paged_attention"): del self.paged_attention_K @@ -83,12 +83,10 @@ def MistralAttention_fast_forward( # Extend RoPE dynamically to fit in VRAM self.rotary_emb.extend_rope_embedding(V, seq_len = kv_seq_len) + cos, sin = self.rotary_emb.get_cached(kv_seq_len, Q.device.index) if position_ids is None: - cos = self.rotary_emb.cos_cached - sin = self.rotary_emb.sin_cached Q, K = fast_rope_embedding(Q, K, cos, sin) else: - cos, sin = self.rotary_emb(V, seq_len = kv_seq_len) Q, K = inplace_rope_embedding(Q, K, cos, sin, position_ids) pass @@ -162,7 +160,7 @@ def MistralAttention_fast_forward( # Go back to (batch_size, seq_len, n_heads, head_dim) A = A.transpose(1, 2).contiguous() pass - + attn_output = A.reshape(bsz, q_len, n_heads*head_dim) attn_output = self.apply_o(self, attn_output) attn_weights = None @@ -201,7 +199,7 @@ def MistralForCausalLM_fast_forward( causal_mask = xformers.attn_bias.BlockDiagonalCausalMask\ .from_seqlens([q_len]*bsz)\ .make_local_attention(window_size = sliding_window) - + elif not HAS_XFORMERS and attention_mask is None: if sliding_window is None or sliding_window == "null" or sliding_window <= 0 or q_len <= sliding_window: # Fully causal mask @@ -212,10 +210,10 @@ def MistralForCausalLM_fast_forward( # Sliding window attention q_indices = torch.arange(q_len, device=input_ids.device).view(-1, 1) k_indices = torch.arange(q_len, device=input_ids.device).view(1, -1) - + causal_bool_mask = k_indices <= q_indices window_bool_mask = (q_indices - k_indices) < sliding_window - + mask = torch.where(causal_bool_mask & window_bool_mask, 0.0, -torch.inf) attention_mask = mask[None, None, :, :].expand(bsz, 1, q_len, q_len) @@ -258,7 +256,7 @@ def MistralForCausalLM_fast_forward( bsz, q_len, hd = hidden_states.shape lm_head = self.lm_head.weight lm_head_device = lm_head.device - + # Move items to same device as lm_head hidden_states = hidden_states.to(lm_head_device) if labels is not None: labels = labels.to(lm_head_device) @@ -301,7 +299,7 @@ def MistralForCausalLM_fast_forward( if not return_dict: output = (logits,) + outputs[1:] return (loss,) + output if loss is not None else output - + output = CausalLMOutputWithPast( loss = loss, logits = EMPTY_LOGITS, @@ -390,7 +388,7 @@ def pre_patch(): MistralForCausalLM .forward = MistralForCausalLM_fast_forward PeftModelForCausalLM .forward = PeftModel_fast_forward fix_prepare_inputs_for_generation(MistralForCausalLM) - + # Solves https://github.com/unslothai/unsloth/issues/168 # Static KV Cache was introduced in 4.38.0, causing training to be much slower. # Inferene can now be CUDAGraphed, but we shall retain the old rotary embeddings. 
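[Editor note, not part of the diff] The decoder loops patched earlier in this diff (cohere, gemma, gemma2, granite, llama) read a _per_layer_device_index attribute from each layer and call move_to_device so the activations follow the layer's GPU during pipeline-parallel inference. A simplified sketch of that loop, with hypothetical layer objects and a trimmed-down copy of the helper added in unsloth/models/_utils.py:

import torch

def move_to_device(target_device, *tensors):
    # Move tensors only if they are not already on the target device.
    target_device = torch.device(target_device)
    moved = [t if t.device == target_device else t.to(target_device) for t in tensors]
    return tuple(moved) if len(moved) > 1 else moved[0]

def run_decoder_layers(layers, hidden_states, position_ids):
    # `_per_layer_device_index` is assumed to be set when the model is sharded
    # across GPUs; consecutive layers on the same GPU trigger no copies.
    for layer in layers:
        device_index = getattr(layer, "_per_layer_device_index", 0)
        hidden_states, position_ids = move_to_device(
            device_index, hidden_states, position_ids
        )
        hidden_states = layer(hidden_states, position_ids)
    return hidden_states

The movement therefore costs one transfer per GPU boundary rather than one per layer, which is what the "once per GPU in PP" comment in the gemma2 hunk refers to.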
diff --git a/unsloth/models/qwen3.py b/unsloth/models/qwen3.py index 80bd6ee7c..b20f22dab 100644 --- a/unsloth/models/qwen3.py +++ b/unsloth/models/qwen3.py @@ -66,7 +66,7 @@ def Qwen3Attention_fast_forward( position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, *args, **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - + # Clear inference if hasattr(self, "paged_attention"): del self.paged_attention_K @@ -111,12 +111,13 @@ def Qwen3Attention_fast_forward( # Extend RoPE dynamically to fit in VRA rotary_emb = self.rotary_emb rotary_emb.extend_rope_embedding(V, seq_len = kv_seq_len) + device_index = Q.device.index if position_ids is None: # Useful for LongRoPE - cos, sin = rotary_emb.get_cached(kv_seq_len) + cos, sin = rotary_emb.get_cached(kv_seq_len, device_index) else: - cos, sin = rotary_emb(V, seq_len = kv_seq_len) + cos, sin = rotary_emb.get_cached(kv_seq_len, device_index) Q, K = fast_rope_embedding(Q, K, cos, sin) if past_key_value is not None: @@ -260,14 +261,14 @@ def Qwen3Attention_fast_forward_inference( self.temp_QA = torch.empty((2, bsz, 1, attention_size), dtype = dtype, device = device) self.temp_KV = torch.empty((2, bsz, 1, n_kv_heads*head_dim), dtype = dtype, device = device) self.RH_Q = torch.empty((bsz, n_heads, 1, head_dim), dtype = dtype, device = device) - + # Mistral Nemo 12b has weird dimensions if attention_size != hidden_size: self.temp_O = torch.empty((1, bsz, hidden_size), dtype = dtype, device = device) else: self.temp_O = self.temp_QA[1][:,:,:hidden_size] pass - + self.attention = torch.empty((bsz, n_heads, 1, KV_CACHE_INCREMENT+seq_len), dtype = dtype, device = device) self.scalar = 1.0 / math_sqrt(self.head_dim) self.half_head_dim = head_dim // 2 @@ -297,7 +298,7 @@ def Qwen3Attention_fast_forward_inference( # Need to do it prior 2 steps before hitting full on short KV cache # or else error self.rotary_emb.extend_rope_embedding(Vn, seq_len + 2) - cos, sin = self.rotary_emb.get_cached(kv_seq_len) + cos, sin = self.rotary_emb.get_cached(kv_seq_len, Qn.device.index) cos = cos[position_ids].unsqueeze(1) sin = sin[position_ids].unsqueeze(1) h = self.half_head_dim @@ -315,7 +316,7 @@ def Qwen3Attention_fast_forward_inference( RH_K[:,:,:,:h].neg_() #torch.neg(RH_K[:,:,:,:h], out = RH_K[:,:,:,:h]) Kn *= cos Kn.addcmul_(RH_K, sin) - + # New KV cache # Kn = torch.cat([K1, Kn], dim = 2) # Vn = torch.cat([V1, Vn], dim = 2) diff --git a/unsloth/models/vision.py b/unsloth/models/vision.py index 53c549742..a358594d8 100644 --- a/unsloth/models/vision.py +++ b/unsloth/models/vision.py @@ -61,7 +61,7 @@ # Old HF Hub versions <= 0.0.25 from huggingface_hub.utils._token import get_token pass -from unsloth import DEVICE_TYPE +from unsloth import DEVICE_TYPE, DEVICE_COUNT __all__ = [ "FastBaseModel", @@ -281,7 +281,6 @@ def from_pretrained( gpu_stats = torch.cuda.get_device_properties(0) gpu_version = torch.version.cuda gpu_stats_snippet = f"CUDA: {gpu_stats.major}.{gpu_stats.minor}. CUDA Toolkit: {gpu_version}." - num_gpus = torch.cuda.device_count() from importlib.metadata import version as importlib_version try: vllm_version = f" vLLM: {importlib_version('vllm')}." @@ -289,7 +288,6 @@ def from_pretrained( elif DEVICE_TYPE == "xpu": gpu_stats = torch.xpu.get_device_properties(0) gpu_version = torch.version.xpu - num_gpus = torch.xpu.device_count() gpu_stats_snippet = f"Intel Toolkit: {gpu_version}." 
# TODO: After adding vLLM support for XPU, changed this @@ -306,11 +304,11 @@ def from_pretrained( statistics = \ f"==((====))== Unsloth {__version__}: Fast {model_type_arch.title()} patching. Transformers: {transformers_version}.{vllm_version}\n"\ - f" {chr(92)}{chr(92)} /| {gpu_stats.name}. Num GPUs = {num_gpus}. Max memory: {max_memory} GB. Platform: {platform_system}.\n"\ + f" {chr(92)}{chr(92)} /| {gpu_stats.name}. Num GPUs = {DEVICE_COUNT}. Max memory: {max_memory} GB. Platform: {platform_system}.\n"\ f"O^O/ {chr(92)}_/ {chr(92)} Torch: {torch.__version__}. {gpu_stats_snippet} Triton: {triton_version}\n"\ f"{chr(92)} / Bfloat16 = {str(SUPPORTS_BFLOAT16).upper()}. FA [Xformers = {xformers_version}. FA2 = {HAS_FLASH_ATTENTION}]\n"\ f' "-____-" Free license: http://github.com/unslothai/unsloth' - + print(statistics) # Warn about fast transfers @@ -325,7 +323,7 @@ def from_pretrained( pass if old_hf_transfer != "0": os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1" - get_statistics() # For debugging - we use a download counter to see if environments are not breaking + get_statistics() # For debugging - we use a download counter to see if environments are not breaking if dtype is None: dtype = torch.float16 if not SUPPORTS_BFLOAT16 else torch.bfloat16 @@ -604,7 +602,7 @@ def get_peft_model( else: assert(type(target_modules) in (list, tuple,)) pass - + # Clear deleted GPU items for _ in range(3): gc.collect() @@ -678,7 +676,7 @@ def post_patch_model( float32_mixed_precision = float32_mixed_precision, ) - from transformers.trainer import Trainer + from transformers.trainer import Trainer if Trainer._inner_training_loop.__name__ != "_fast_inner_training_loop" and trust_remote_code == False: raise RuntimeError('Unsloth: Unsuccessfully patched inner_training_loop') pass
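Editor's note on the RoPE changes above: the mistral.py and qwen3.py hunks replace reads of a single cos_cached / sin_cached pair (or a fresh rotary_emb(V, seq_len=...) call) with rotary_emb.get_cached(kv_seq_len, device_index), passing Q.device.index so each GPU reads rotary tables that already live on it. The sketch below is not Unsloth's implementation; it is a minimal, hypothetical per-device cache (the class name PerDeviceRopeCache is illustrative; only the two method names mirror the call sites in the diff) showing the pattern those call sites assume.

# Illustrative sketch only -- not part of the patch.
import torch

class PerDeviceRopeCache:
    def __init__(self, dim, base = 10000.0, dtype = torch.float16):
        self.dim = dim
        self.base = base
        self.dtype = dtype
        self.cache = {}  # device index -> (cos, sin), each of shape (seq_len, dim)

    def extend_rope_embedding(self, x, seq_len):
        # (Re)build cos/sin on x's device when the cached table is too short.
        index = x.device.index if x.device.index is not None else 0
        cached = self.cache.get(index)
        if cached is not None and cached[0].shape[0] >= seq_len:
            return
        inv_freq = 1.0 / (self.base ** (
            torch.arange(0, self.dim, 2, device = x.device, dtype = torch.float32) / self.dim
        ))
        t = torch.arange(seq_len, device = x.device, dtype = torch.float32)
        freqs = torch.outer(t, inv_freq)           # (seq_len, dim // 2)
        emb = torch.cat((freqs, freqs), dim = -1)  # (seq_len, dim)
        self.cache[index] = (emb.cos().to(self.dtype), emb.sin().to(self.dtype))

    def get_cached(self, seq_len, device_index):
        # Mirrors the two-argument call sites above: the caller passes
        # Q.device.index so each GPU reads tables stored on that same device.
        cos, sin = self.cache[device_index if device_index is not None else 0]
        return cos[:seq_len], sin[:seq_len]

A call site would then follow the same shape as the patched code: extend_rope_embedding(V, seq_len = kv_seq_len) first, then cos, sin = get_cached(kv_seq_len, Q.device.index), which avoids cross-device copies when DEVICE_COUNT > 1.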