20 changes: 20 additions & 0 deletions unsloth/models/_utils.py
@@ -781,6 +781,26 @@ def _prepare_backend(self, *args, **kwargs): return None, DistributedType.NO
accelerate.accelerator.Accelerator.distributed_type = lambda *args, **kwargs: DistributedType.NO
pass

# Helper to move multiple tensors onto the same device
def move_to_device(target_device, *tensors):
"""
Move multiple tensors to target device if they're not already there.

Args:
target_device: The target device to move tensors to
*tensors: Variable number of tensors to potentially move

Returns:
tuple: The tensors on the target device (same objects if already there, new copies if moved); if a single tensor is passed, that tensor is returned directly
"""
moved_tensors = []
for tensor in tensors:
if tensor.device != target_device:
moved_tensors.append(tensor.to(target_device))
else:
moved_tensors.append(tensor)
return tuple(moved_tensors) if len(moved_tensors) > 1 else moved_tensors[0]

import transformers.utils.quantization_config
transformers.utils.quantization_config.BitsAndBytesConfig.__init__ = _BitsAndBytesConfig__init__
# =============================================
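A minimal usage sketch of the new helper, assuming move_to_device is importable from unsloth.models._utils after this change and that a CUDA device (or the CPU fallback shown) is available; the sample tensors are illustrative only, not part of this PR.

import torch
from unsloth.models._utils import move_to_device  # helper added above

# Illustrative target device; fall back to CPU when no GPU is visible.
target = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

hidden_states = torch.randn(1, 4, 16)         # dummy activations, created on CPU
position_ids  = torch.arange(4).unsqueeze(0)  # dummy position ids, created on CPU

# Tensors already on `target` are returned as-is; the others are copied over once.
hidden_states, position_ids = move_to_device(target, hidden_states, position_ids)
assert hidden_states.device == target and position_ids.device == target

The decoder loops changed below call this once per layer with decoder_layer._per_layer_device as the target, so an actual copy only happens when execution crosses a pipeline-parallel device boundary.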
16 changes: 10 additions & 6 deletions unsloth/models/cohere.py
@@ -78,7 +78,7 @@ def CohereAttention_fast_forward(
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
*args, **kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:

# Clear inference
if hasattr(self, "paged_attention"):
del self.paged_attention_K
@@ -254,6 +254,7 @@ def CohereAttention_fast_forward_inference(
do_prefill = False,
attention_mask = None,
):

Xn = hidden_states
bsz, _, hd = hidden_states.size()
K1, V1 = past_key_value
@@ -281,14 +282,14 @@ def CohereAttention_fast_forward_inference(
self.temp_QA = torch.empty((2, bsz, 1, attention_size), dtype = dtype, device = "cuda:0")
self.temp_KV = torch.empty((2, bsz, 1, n_kv_heads*head_dim), dtype = dtype, device = "cuda:0")
self.RH_Q = torch.empty((bsz, n_heads, 1, head_dim), dtype = dtype, device = "cuda:0")

# Mistral Nemo 12b has weird dimensions
if attention_size != hidden_size:
self.temp_O = torch.empty((1, bsz, hidden_size), dtype = dtype, device = "cuda:0")
else:
self.temp_O = self.temp_QA[1][:,:,:hidden_size]
pass

self.attention = torch.empty((bsz, n_heads, 1, KV_CACHE_INCREMENT+seq_len), dtype = dtype, device = "cuda:0")
self.scalar = 1.0 / math_sqrt(self.head_dim)
self.half_head_dim = head_dim // 2
@@ -320,7 +321,7 @@ def CohereAttention_fast_forward_inference(

# cos, sin = self.rotary_emb(Vn, seq_len = kv_seq_len)
# Qn, Kn = inplace_rope_embedding(Qn, Kn, cos, sin, position_ids)
cos, sin = self.rotary_emb.get_cached(kv_seq_len)
cos, sin = self.rotary_emb.get_cached(kv_seq_len, device = Qn.device)
cos = cos[position_ids].unsqueeze(1)
sin = sin[position_ids].unsqueeze(1)
h = self.half_head_dim
@@ -338,7 +339,7 @@
torch.neg(RH_K[:,:,:,:h], out = RH_K[:,:,:,:h])
Kn *= cos
Kn.addcmul_(RH_K, sin)

# New KV cache
# Kn = torch.cat([K1, Kn], dim = 2)
# Vn = torch.cat([V1, Vn], dim = 2)
@@ -417,6 +418,9 @@ def CohereModel_fast_forward_inference(

next_decoder_cache = []
for idx, decoder_layer in enumerate(self.model.layers):
hidden_states, out_weight, position_ids = move_to_device(
decoder_layer._per_layer_device, hidden_states, out_weight, position_ids
)
residual = hidden_states
hidden_states = fast_layernorm_inference(decoder_layer.input_layernorm, hidden_states, out_weight)
hidden_states_attention, present_key_value = CohereAttention_fast_forward_inference(
@@ -468,7 +472,7 @@ def pre_patch():
CohereForCausalLM .forward = CausalLM_fast_forward(CohereModel_fast_forward_inference)
PeftModelForCausalLM .forward = PeftModel_fast_forward
fix_prepare_inputs_for_generation(CohereForCausalLM)

import transformers.models.cohere.modeling_cohere
transformers.models.cohere.modeling_cohere.CohereRotaryEmbedding = LlamaRotaryEmbedding
return
6 changes: 3 additions & 3 deletions unsloth/models/falcon_h1.py
@@ -113,9 +113,9 @@ def FalconH1Attention_fast_forward(

if position_ids is None:
# Useful for LongRoPE
cos, sin = rotary_emb.get_cached(kv_seq_len)
cos, sin = rotary_emb.get_cached(kv_seq_len, device=Q.device)
else:
cos, sin = rotary_emb(V, seq_len = kv_seq_len)
cos, sin = rotary_emb.get_cached(seq_len = kv_seq_len, device=Q.device)
Q, K = fast_rope_embedding(Q, K, cos, sin)

if past_key_value is not None:
@@ -280,7 +280,7 @@ def FalconH1Attention_fast_forward_inference(
# Need to do it prior 2 steps before hitting full on short KV cache
# or else error
self.rotary_emb.extend_rope_embedding(Vn, seq_len + 2)
cos, sin = self.rotary_emb.get_cached(kv_seq_len)
cos, sin = self.rotary_emb.get_cached(kv_seq_len, device = Qn.device)
cos = cos[position_ids].unsqueeze(1)
sin = sin[position_ids].unsqueeze(1)
h = self.half_head_dim
48 changes: 32 additions & 16 deletions unsloth/models/gemma.py
@@ -170,6 +170,10 @@ def GemmaModel_fast_forward_inference(

next_decoder_cache = []
for idx, decoder_layer in enumerate(self.model.layers):
hidden_states, out_weight, position_ids = move_to_device(
decoder_layer._per_layer_device, hidden_states, out_weight, position_ids
)

residual = hidden_states
hidden_states = fast_rms_layernorm_inference_gemma(decoder_layer.input_layernorm, hidden_states, out_weight)
hidden_states, present_key_value = LlamaAttention_fast_forward_inference(
@@ -224,9 +228,16 @@ def __init__(self, dim = None, max_position_embeddings=2048, base=10000, device=
self.base = base
# Dynamic RoPE we first set it to a max of 4 * 8192 tokens then we iteratively grow this
self.current_rope_size = min(4 * 8192, self.max_position_embeddings)
self.multi_gpu_cos_cached = [None]*torch.cuda.device_count()
self.multi_gpu_sin_cached = [None]*torch.cuda.device_count()

# Build here to make `torch.jit.trace` work.
self._set_cos_sin_cache(seq_len=self.current_rope_size, device=device, dtype=torch.get_default_dtype())
for device in range(torch.cuda.device_count()):
self._set_cos_sin_cache(seq_len=self.current_rope_size, device=torch.device(device), dtype=torch.get_default_dtype())

# dummy so that patch_utils doesn't fail for now
self.cos_cached = torch.empty(1, device=torch.cuda.current_device(), dtype=torch.get_default_dtype())
self.sin_cached = torch.empty(1, device=torch.cuda.current_device(), dtype=torch.get_default_dtype())
pass

def _set_cos_sin_cache(self, seq_len, device, dtype):
Expand All @@ -245,32 +256,36 @@ def _set_cos_sin_cache(self, seq_len, device, dtype):

emb = torch.cat((radians_new, radians_new), dim = -1)
# We must do RoPE in float32!
cos = emb.cos().to(device = "cuda", non_blocking = True)#, dtype = dtype)
sin = emb.sin().to(device = "cuda", non_blocking = True)#, dtype = dtype)
self.register_buffer("cos_cached", cos, persistent = False)
self.register_buffer("sin_cached", sin, persistent = False)
cos = emb.cos().to(device = device, non_blocking = True)#, dtype = dtype)
sin = emb.sin().to(device = device, non_blocking = True)#, dtype = dtype)
self.multi_gpu_cos_cached[device.index] = cos
self.multi_gpu_sin_cached[device.index] = sin
return cos, sin
pass

def forward(self, x, position_ids=None, seq_len=None):
# x: [bs, num_attention_heads, seq_len, head_size]
if seq_len > self.current_rope_size:
if seq_len is not None and seq_len > self.current_rope_size:
self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)

return (
self.cos_cached[:seq_len].to(dtype=x.dtype),
self.sin_cached[:seq_len].to(dtype=x.dtype),
self.multi_gpu_cos_cached[x.device.index][:seq_len].to(dtype=x.dtype),
self.multi_gpu_sin_cached[x.device.index][:seq_len].to(dtype=x.dtype),
)
pass

def get_cached(self, seq_len = None):
return self.cos_cached, self.sin_cached
def get_cached(self, seq_len = None, device = None):
if device is None:
device = torch.device(torch.cuda.current_device())
return self.multi_gpu_cos_cached[device.index], self.multi_gpu_sin_cached[device.index]
pass

def extend_rope_embedding(self, x, seq_len):
if seq_len <= self.current_rope_size: return
# Iteratively grow by increments of 8192
self.current_rope_size = math.ceil(seq_len / 8192) * 8192
self._set_cos_sin_cache(self.current_rope_size, device = "cuda", dtype = x.dtype)
for device in range(torch.cuda.device_count()):
self._set_cos_sin_cache(self.current_rope_size, device = torch.device(device), dtype = x.dtype)
pass
pass

@@ -288,7 +303,7 @@ def __init__(self, dim = None, max_position_embeddings=2048, base=10000, device=
pass

def _set_cos_sin_cache(self, seq_len, device, dtype):
# Note: on the original Llama codebase, these tensors are created on the target device (and not on CPU) and
# Note: on the original Llama codebase, these tensors are created on the target device (and not on CPU) and
# in FP32. They are applied (multiplied) in FP32 as well.
self.current_rope_size = seq_len

Expand All @@ -304,10 +319,11 @@ def _set_cos_sin_cache(self, seq_len, device, dtype):

emb = torch.cat((radians_new, radians_new), dim = -1)
# We must do RoPE in float32!
cos = emb.cos().to(device = "cuda", non_blocking = True)#, dtype = dtype)
sin = emb.sin().to(device = "cuda", non_blocking = True)#, dtype = dtype)
self.register_buffer("cos_cached", cos, persistent = False)
self.register_buffer("sin_cached", sin, persistent = False)
cos = emb.cos().to(device = device, non_blocking = True)#, dtype = dtype)
sin = emb.sin().to(device = device, non_blocking = True)#, dtype = dtype)
self.multi_gpu_cos_cached[device.index] = cos
self.multi_gpu_sin_cached[device.index] = sin
return cos, sin
pass
pass
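For reference, a condensed sketch of the per-device cos/sin cache pattern that the Gemma rotary embedding classes above now follow; the class name, default dimensions, and the CPU fallback are assumptions made so the standalone example runs, not part of the PR.

import torch

class PerDeviceRotaryCache:
    # One cos/sin pair per visible device, so lookups never copy across GPUs.
    def __init__(self, dim=256, seq_len=8192, base=10000.0):
        n_devices = max(torch.cuda.device_count(), 1)
        self.multi_gpu_cos_cached = [None] * n_devices
        self.multi_gpu_sin_cached = [None] * n_devices
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
        t = torch.arange(seq_len, dtype=torch.float32)
        radians = torch.outer(t, inv_freq)
        emb = torch.cat((radians, radians), dim=-1)  # RoPE is computed in float32
        for i in range(n_devices):
            device = torch.device(f"cuda:{i}") if torch.cuda.is_available() else torch.device("cpu")
            self.multi_gpu_cos_cached[i] = emb.cos().to(device)
            self.multi_gpu_sin_cached[i] = emb.sin().to(device)

    def get_cached(self, seq_len=None, device=None):
        # Serve the cache that already lives on the caller's device.
        index = 0 if device is None or device.index is None else device.index
        return self.multi_gpu_cos_cached[index], self.multi_gpu_sin_cached[index]

The real forward() and get_cached() above index the per-device lists with x.device.index, which is why the attention paths in this PR now pass device = Q.device (or Qn.device) when fetching the cache.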

25 changes: 16 additions & 9 deletions unsloth/models/gemma2.py
@@ -84,7 +84,7 @@ def Gemma2Attention_fast_forward(
padding_mask: Optional[torch.LongTensor] = None,
*args, **kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:

# Clear inference
if hasattr(self, "paged_attention"):
del self.paged_attention_K
@@ -114,11 +114,11 @@ def Gemma2Attention_fast_forward(
kv_seq_len += past_key_value[0].shape[-2]

if position_ids is None:
cos = self.rotary_emb.cos_cached
sin = self.rotary_emb.sin_cached
cos = self.rotary_emb.multi_gpu_cos_cached[Q.device.index]
sin = self.rotary_emb.multi_gpu_sin_cached[Q.device.index]
Q, K = fast_rope_embedding(Q, K, cos, sin)
else:
cos, sin = self.rotary_emb(V, seq_len = kv_seq_len)
cos, sin = self.rotary_emb.get_cached(seq_len = kv_seq_len, device = Q.device)
Q, K = inplace_rope_embedding(Q, K, cos, sin, position_ids)
pass

@@ -281,7 +281,7 @@ def Gemma2Attention_fast_forward_inference(
# Only for Gemma2
self.temp_O = torch.empty((1, bsz, hidden_size), dtype = dtype, device = device)
self.attention = torch.empty((bsz, n_heads, 1, KV_CACHE_INCREMENT+seq_len), dtype = dtype, device = device)

# See https://github.com/google/gemma_pytorch/commit/03e657582d17cb5a8617ebf333c1c16f3694670e
# Gemma 9b should use 256 and not 224 (hs / nah). 27b uses the below
# We default to using the config file itself
@@ -307,8 +307,9 @@ def Gemma2Attention_fast_forward_inference(

# cos, sin = self.rotary_emb(Vn, seq_len = kv_seq_len)
# Qn, Kn = inplace_rope_embedding(Qn, Kn, cos, sin, position_ids)
cos = self.rotary_emb.cos_cached[position_ids].unsqueeze(1)
sin = self.rotary_emb.sin_cached[position_ids].unsqueeze(1)
cos, sin = self.rotary_emb.get_cached(seq_len = kv_seq_len, device = Qn.device)
cos = cos[position_ids].unsqueeze(1)
sin = sin[position_ids].unsqueeze(1)
h = self.half_head_dim

RH_Q = self.RH_Q
@@ -324,7 +325,7 @@
torch.neg(RH_K[:,:,:,:h], out = RH_K[:,:,:,:h])
Kn *= cos
Kn.addcmul_(RH_K, sin)

# New KV cache
# Kn = torch.cat([K1, Kn], dim = 2)
# Vn = torch.cat([V1, Vn], dim = 2)
@@ -422,6 +423,12 @@ def Gemma2Model_fast_forward_inference(
next_decoder_cache = []
for idx, decoder_layer in enumerate(self.model.layers):

# For pipeline parallelism, we need to move all tensors to the same device
# note that the actual data movement happens only once per GPU in pipeline parallelism
hidden_states, out_weight, position_ids = move_to_device(
decoder_layer._per_layer_device, hidden_states, out_weight, position_ids
)

use_sliding_window = idx % 2 == 0

residual = hidden_states
@@ -479,7 +486,7 @@ def pre_patch():
Gemma2ForCausalLM .forward = CausalLM_fast_forward(Gemma2Model_fast_forward_inference)
PeftModelForCausalLM .forward = PeftModel_fast_forward
fix_prepare_inputs_for_generation(Gemma2ForCausalLM)

# Solves https://github.com/unslothai/unsloth/issues/168
# Static KV Cache was introduced in 4.38.0, causing training to be much slower.
# Inference can now be CUDAGraphed, but we shall retain the old rotary embeddings.
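To make the per-layer movement added to the decoder loop in this file (and in cohere.py, gemma.py, and granite.py) concrete, here is a self-contained sketch; the toy layers, devices, and tensors are stand-ins, and move_to_device is re-declared locally with the same contract as the helper added in _utils.py.

import torch

def move_to_device(target_device, *tensors):
    # Same contract as the _utils.py helper: no-op for tensors already on
    # target_device, a single .to() copy otherwise.
    moved = [t if t.device == target_device else t.to(target_device) for t in tensors]
    return tuple(moved) if len(moved) > 1 else moved[0]

class ToyDecoderLayer:
    # Stand-in for a decoder layer pinned to one pipeline-parallel device.
    def __init__(self, device):
        self._per_layer_device = torch.device(device)
    def __call__(self, hidden_states, position_ids):
        return hidden_states + 1.0  # placeholder for attention + MLP

devices = ["cuda:0", "cuda:1"] if torch.cuda.device_count() > 1 else ["cpu", "cpu"]
layers = [ToyDecoderLayer(d) for d in devices]

hidden_states = torch.zeros(1, 4, 8)
position_ids  = torch.arange(4).unsqueeze(0)

for layer in layers:
    # Copies only happen when execution crosses a device boundary.
    hidden_states, position_ids = move_to_device(layer._per_layer_device, hidden_states, position_ids)
    hidden_states = layer(hidden_states, position_ids)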
19 changes: 11 additions & 8 deletions unsloth/models/granite.py
@@ -71,7 +71,7 @@ def GraniteAttention_fast_forward(
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
*args, **kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:

# Clear inference
if hasattr(self, "paged_attention"):
del self.paged_attention_K
@@ -162,7 +162,7 @@ def GraniteAttention_fast_forward(
# Go back to (batch_size, seq_len, n_heads, head_dim)
A = A.transpose(1, 2).contiguous()
pass

attn_output = A.reshape(bsz, q_len, n_heads*head_dim)
attn_output = self.apply_o(self, attn_output)
attn_weights = None
@@ -256,7 +256,7 @@ def GraniteAttention_fast_forward_inference(
use_sliding_window = False,
position_embeddings : Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
):

assert position_embeddings is not None, f"Granite model requires position embeddings to be specified"

Xn = hidden_states
@@ -326,7 +326,7 @@ def GraniteAttention_fast_forward_inference(
torch.neg(RH_K[:,:,:,:h], out = RH_K[:,:,:,:h])
Kn *= cos
Kn.addcmul_(RH_K, sin)

# New KV cache
# Kn = torch.cat([K1, Kn], dim = 2)
# Vn = torch.cat([V1, Vn], dim = 2)
@@ -349,7 +349,7 @@

Qn *= self.scaling
A = torch_matmul(Qn, Kn.transpose(2, 3), out = self.attention[:,:,:,:cached_len])

# if attention_mask is not None: A += attention_mask # Must add attention_mask for batched

A[:] = torch_nn_functional_softmax(A, dim = -1, dtype = torch.float32)#.to(A.dtype)
@@ -395,11 +395,15 @@ def GraniteModel_fast_forward_inference(
attention_mask = None
pass

position_embeddings = self.model.rotary_emb(hidden_states, position_ids, self.max_seq_length)
position_embeddings = self.model.rotary_emb.get_cached(seq_len = self.max_seq_length, device = hidden_states.device)

next_decoder_cache = []
for idx, decoder_layer in enumerate(self.model.layers):

hidden_states, position_ids = move_to_device(
decoder_layer._per_layer_device, hidden_states, position_ids
)

residual = hidden_states
hidden_states = fast_rms_layernorm_inference(decoder_layer.input_layernorm, hidden_states)
hidden_states, present_key_value = GraniteAttention_fast_forward_inference(
@@ -532,7 +536,7 @@ def post_patch(model, tokenizer):

elif hasattr(module, "short_cos_cached") and \
(module.short_cos_cached.dtype != correct_dtype):

module.short_cos_cached = module.short_cos_cached.to(correct_dtype)
module.short_sin_cached = module.short_sin_cached.to(correct_dtype)
pass
@@ -547,4 +551,3 @@ def post_patch(model, tokenizer):
return model, tokenizer
pass
pass
