diff --git a/unsloth_zoo/gradient_checkpointing.py b/unsloth_zoo/gradient_checkpointing.py index 5e406e02..5507bae3 100644 --- a/unsloth_zoo/gradient_checkpointing.py +++ b/unsloth_zoo/gradient_checkpointing.py @@ -320,7 +320,9 @@ def initialize_unsloth_gradient_checkpointing(dtype = None): CPU_BUFFERS = [] CPU_INDEX = 0 - if dtype is None: + if os.environ.get("UNSLOTH_FORCE_FLOAT32", "0") == "1": + dtype = torch.float32 + elif dtype is None: if DEVICE_TYPE == "cuda": major_version, minor_version = torch.cuda.get_device_capability() SUPPORTS_BFLOAT16 = (major_version >= 8) @@ -604,6 +606,112 @@ def backward(ctx, *args): pass +class UnslothCheckpointFunction_Float32(torch.autograd.Function): + """Special version for FORCE_FLOAT32 mode without AMP""" + + @staticmethod + def forward(ctx, run_function, preserve_rng_state, *args): + # Same as original but WITHOUT @torch_amp_custom_fwd decorator + ctx.run_function = run_function + ctx.preserve_rng_state = preserve_rng_state + ctx.device_type = _infer_device_type(*args) + + if preserve_rng_state: + ctx.fwd_cpu_state = torch.get_rng_state() + ctx.had_device_in_fwd = False + device_module = _get_device_module(ctx.device_type) + if getattr(device_module, "_initialized", False): + ctx.had_device_in_fwd = True + ctx.fwd_devices, ctx.fwd_device_states = get_device_states(*args) + + ctx.inputs = [] + ctx.tensor_indices = [] + tensor_inputs = [] + ctx._requires_gradient = False + + for i, arg in enumerate(args): + if torch.is_tensor(arg): + tensor_inputs.append(arg) + ctx.tensor_indices.append(i) + ctx.inputs.append(None) + if arg.requires_grad: + ctx._requires_gradient = True + else: + ctx.inputs.append(arg) + + if ctx._requires_gradient: + ctx.save_for_backward(*tensor_inputs) + + with torch.no_grad(): + outputs = run_function(*args) + + return outputs + + @staticmethod + def backward(ctx, *args): + # Same as original but WITHOUT @torch_amp_custom_bwd decorator + if not ctx._requires_gradient: + return None + + if not torch.autograd._is_checkpoint_valid(): + raise RuntimeError( + "When use_reentrant=True, torch.utils.checkpoint is incompatible" + " with .grad() or passing an `inputs` parameter to .backward()." 
+ ) + + inputs = list(ctx.inputs) + tensor_indices = ctx.tensor_indices + tensors = ctx.saved_tensors + + for i, idx in enumerate(tensor_indices): + inputs[idx] = tensors[i] + + rng_devices = [] + if ctx.preserve_rng_state and ctx.had_device_in_fwd: + rng_devices = ctx.fwd_devices + + with torch.random.fork_rng( + devices=rng_devices, enabled=ctx.preserve_rng_state, device_type=ctx.device_type + ): + if ctx.preserve_rng_state: + torch.set_rng_state(ctx.fwd_cpu_state) + if ctx.had_device_in_fwd: + set_device_states(ctx.fwd_devices, ctx.fwd_device_states, device_type=ctx.device_type) + + detached_inputs = [] + for inp in inputs: + if not isinstance(inp, torch.Tensor): + detached_inputs.append(inp) + continue + x = inp.detach() + x.requires_grad = inp.requires_grad + detached_inputs.append(x) + + # NO autocast context here - just pure computation + with torch.enable_grad(): + outputs = ctx.run_function(*detached_inputs) + + if isinstance(outputs, torch.Tensor): + outputs = (outputs,) + + outputs_with_grad = [] + args_with_grad = [] + for i in range(len(outputs)): + if torch.is_tensor(outputs[i]) and outputs[i].requires_grad: + outputs_with_grad.append(outputs[i]) + args_with_grad.append(args[i]) + + if len(outputs_with_grad) > 0: + torch.autograd.backward(outputs_with_grad, args_with_grad) + + grads = tuple( + inp.grad if isinstance(inp, torch.Tensor) else None + for inp in detached_inputs + ) + + return (None, None) + grads +pass + from torch.utils.checkpoint import ( ContextManager, _DEFAULT_DETERMINISM_MODE, @@ -756,7 +864,11 @@ def unsloth_checkpoint( "Passing `context_fn` or `debug` is only supported when " "use_reentrant=False." ) - return UnslothCheckpointFunction.apply(function, preserve, *args) + # Choose the correct checkpoint function based on FORCE_FLOAT32 + if os.environ.get("UNSLOTH_FORCE_FLOAT32", "0") == "1": + return UnslothCheckpointFunction_Float32.apply(function, preserve, *args) + else: + return UnslothCheckpointFunction.apply(function, preserve, *args) else: gen = _checkpoint_without_reentrant_generator( function, preserve, context_fn, determinism_check, debug, *args, **kwargs @@ -774,13 +886,23 @@ def unsloth_checkpoint( def patch_unsloth_smart_gradient_checkpointing(dtype = None): # All Unsloth Zoo code licensed under LGPLv3 - if torch.utils.checkpoint.CheckpointFunction.__name__ != "UnslothCheckpointFunction": + + # Use float32 version if FORCE_FLOAT32 is set + if os.environ.get("UNSLOTH_FORCE_FLOAT32", "0") == "1": + # Force dtype to float32 for buffer initialization + initialize_unsloth_gradient_checkpointing(torch.float32) + + # Store the correct checkpoint function + if not hasattr(torch.utils.checkpoint, "_checkpoint_mode"): + torch.utils.checkpoint._checkpoint_mode = "float32" + else: initialize_unsloth_gradient_checkpointing(dtype) - torch.utils.checkpoint._old_CheckpointFunction = torch.utils.checkpoint.CheckpointFunction - torch.utils.checkpoint.CheckpointFunction = UnslothCheckpointFunction + if not hasattr(torch.utils.checkpoint, "_checkpoint_mode"): + torch.utils.checkpoint._checkpoint_mode = "normal" + # Always patch checkpoint function if torch.utils.checkpoint.checkpoint.__name__ != "unsloth_checkpoint": - torch.utils.checkpoint._old_checkpoint = torch.utils.checkpoint + torch.utils.checkpoint._old_checkpoint = torch.utils.checkpoint.checkpoint torch.utils.checkpoint.checkpoint = unsloth_checkpoint pass @@ -800,7 +922,7 @@ def unpatch_unsloth_smart_gradient_checkpointing(): if (torch.utils.checkpoint.checkpoint.__name__ == "unsloth_checkpoint") and \ 
hasattr(torch.utils, "_old_checkpoint"): - + torch.utils.checkpoint = torch.utils._old_checkpoint pass diff --git a/unsloth_zoo/patching_utils.py b/unsloth_zoo/patching_utils.py index 1d72440e..c1e2877a 100644 --- a/unsloth_zoo/patching_utils.py +++ b/unsloth_zoo/patching_utils.py @@ -242,19 +242,30 @@ def patch_model_and_tokenizer( pass # If we force float32, we first use bfloat16, then downcast to float16 if do_forced_float32: - correct_dtype = torch.float16 - for name, module in model.named_modules(): - if "down_proj" in name or "up_proj" in name or "gate_proj" in name: - exec(f"module.to(torch.float16)") - if "q_proj" in name or "k_proj" in name or "v_proj" in name or "o_proj" in name: - exec(f"module.to(torch.float16)") - if "lm_head" in name or "embed_tokens" in name: - exec(f"module.to(torch.float16)") - if "norm" in name: - exec(f"module.to(torch.float32)") - assert(module.weight.dtype == torch.float32) - torch.cuda.empty_cache() - pass + correct_dtype = torch.float16 + for name, module in model.named_modules(): + if "down_proj" in name or "up_proj" in name or "gate_proj" in name or "fc1" in name or "fc2" in name: + module.to(torch.float16) + if "q_proj" in name or "k_proj" in name or "v_proj" in name or "o_proj" in name or "out_proj" in name: + module.to(torch.float16) + if "lm_head" in name or "embed_tokens" in name: + module.to(torch.float16) + if "embed_tokens" in name or "patch_embedding" in name: + module.to(torch.float16) + if "norm" in name: + module.to(torch.float16) + torch.cuda.empty_cache() + + # Convert any remaining bfloat16 parameters + for name, param in model.named_parameters(): + if param.dtype == torch.bfloat16: + param.data = param.data.to(torch.float16) + + # Also convert buffers (like position embeddings) + for name, buffer in model.named_buffers(): + if buffer.dtype == torch.bfloat16: + buffer.data = buffer.data.to(torch.float16) + pass pass # Correct torch_dtype @@ -279,7 +290,7 @@ def __fix_dtype(config): try: setattr(m, "dtype", correct_dtype) except: pass pass - + # Check all params and patch! 
for name, module in model.named_modules(): if isinstance(module, (Bnb_Linear4bit, Peft_Linear4bit)): @@ -313,7 +324,7 @@ def __fix_dtype(config): elif hasattr(module, "short_cos_cached") and \ (module.short_cos_cached.dtype != correct_dtype): - + module.short_cos_cached = module.short_cos_cached.to(correct_dtype) module.short_sin_cached = module.short_sin_cached.to(correct_dtype) pass @@ -373,7 +384,7 @@ def __fix_dtype(config): lm_head.weight = old_output_embedding if not is_tied else old_input_embedding lm_head.in_features = lm_head.weight.shape[1] lm_head.out_features = lm_head.weight.shape[0] - + lm_head.weight.requires_grad_(requires_grad) model.set_output_embeddings(lm_head) if hasattr(model, "lm_head"): model.lm_head = lm_head @@ -457,7 +468,7 @@ def check_conversion_mappings(model, current_key_name_str, skip_modules): if hasattr(model_root_cls, "_checkpoint_conversion_mapping") and len(model_root_cls._checkpoint_conversion_mapping) > 0: # if this is true, then it means that we must be on transformers >=4.52.0 because conversion_mappings was added in 4.52.0 # we cant know if the skip module naming convention is new or old - # but if we are supposed to skip this current_key_name_str, and it didn't pass + # but if we are supposed to skip this current_key_name_str, and it didn't pass # (current_key_name_str in quantization_config.llm_int8_skip_modules) # then new transformers + new module hierarchy means it should not be skipped, ie no BC check needed # and new transformers + old module hierarchy means we still need to check to skip @@ -605,9 +616,9 @@ def visit_Assign(self, node: ast.Assign): def add_score_code(match): indentation = match.group(1) # Captured indentation line_content = match.group(2) # The line 'current_key_name.append(name)' - + indented_breakpoint_code = "\n".join([f"{indentation}{line}" for line in score_code.splitlines()]) - + return f"{indentation}{line_content}\n{indented_breakpoint_code}" source = re.sub(pattern, add_score_code, source, flags=re.MULTILINE) diff --git a/unsloth_zoo/temporary_patches/gemma.py b/unsloth_zoo/temporary_patches/gemma.py index db34c2ff..efb45e26 100644 --- a/unsloth_zoo/temporary_patches/gemma.py +++ b/unsloth_zoo/temporary_patches/gemma.py @@ -100,9 +100,9 @@ def __call__( raise ValueError( f"Prompt contained {len(image_indexes)} image tokens but received {len(images_for_item)} images." 
) - + iterable_num_crops = num_crops_for_item - + if isinstance(num_crops_for_item, int): if len(image_indexes) > 0: iterable_num_crops = [num_crops_for_item] + [0] * (len(image_indexes) - 1) @@ -155,22 +155,107 @@ def __call__( pass TEMPORARY_PATCHES.append(patch_Gemma3Processor) -def patch_Gemma3ForConditionalGeneration(): +def patch_Gemma3ForConditionalGeneration_forward_router(): + try: + import transformers.models.gemma3.modeling_gemma3 + except: + return + from transformers.models.gemma3.modeling_gemma3 import ( + Gemma3CausalLMOutputWithPast, + Cache, + ) + + def forward_router( + self, + input_ids: torch.LongTensor = None, + pixel_values: torch.FloatTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Union[List[torch.FloatTensor], Cache]] = None, + token_type_ids: Optional[torch.LongTensor] = None, + cache_position: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + logits_to_keep: Union[int, torch.Tensor] = 0, + **lm_kwargs, + ) -> Union[Tuple, Gemma3CausalLMOutputWithPast]: + # Routing logic + is_text_only = ( + pixel_values is None and + token_type_ids is None and + (input_ids is not None or inputs_embeds is not None) + ) + + if is_text_only: + return self.forward_llm( + input_ids, + pixel_values, + attention_mask, + position_ids, + past_key_values, + token_type_ids, + cache_position, + inputs_embeds, + labels, + use_cache, + output_attentions, + output_hidden_states, + return_dict, + logits_to_keep, + **lm_kwargs) + else: + return self.forward_multimodal( + input_ids, + pixel_values, + attention_mask, + position_ids, + past_key_values, + token_type_ids, + cache_position, + inputs_embeds, + labels, + use_cache, + output_attentions, + output_hidden_states, + return_dict, + logits_to_keep, + **lm_kwargs) + + old_keys = inspect.signature(transformers.models.gemma3.modeling_gemma3.Gemma3ForConditionalGeneration.forward).parameters + new_keys = inspect.signature(forward_router).parameters + if old_keys != new_keys: + if UNSLOTH_ENABLE_LOGGING: + print("Unsloth: Failed patching Gemma3ForConditionalGeneration") + else: + transformers.models.gemma3.modeling_gemma3.Gemma3ForConditionalGeneration.forward = forward_router + return +pass +TEMPORARY_PATCHES.append(patch_Gemma3ForConditionalGeneration_forward_router) + + +def patch_Gemma3ForConditionalGeneration_forward_multimodal(): try: import transformers.models.gemma3.modeling_gemma3 except: return from transformers.models.gemma3.modeling_gemma3 import ( - HybridCache, - Gemma3CausalLMOutputWithPast, - logger, - is_torchdynamo_compiling, - Cache, - ) - def forward( + HybridCache, + Gemma3CausalLMOutputWithPast, + BaseModelOutputWithPast, + CausalLMOutputWithPast, + logger, + is_torchdynamo_compiling, + Cache, + ) + + def forward_multimodal( self, - input_ids: Optional[torch.LongTensor] = None, - pixel_values: Optional[torch.FloatTensor] = None, + input_ids: torch.LongTensor = None, + pixel_values: torch.FloatTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[Union[List[torch.FloatTensor], Cache]] = None, @@ -181,77 +266,365 @@ def forward( use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, 
output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, logits_to_keep: Union[int, torch.Tensor] = 0, **lm_kwargs, ) -> Union[Tuple, Gemma3CausalLMOutputWithPast]: - if (input_ids is None) ^ (inputs_embeds is not None): - raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + + is_training = token_type_ids is not None and labels is not None + + # Replace image id woth PAD if the image token if OOV, to avoid index-errors + #if input_ids is not None and self.config.image_token_index >= self.vocab_size: + if input_ids is not None and self.config.image_token_index >= self.config.text_config.vocab_size: + special_image_mask = input_ids == self.config.image_token_index + llm_input_ids = input_ids.clone() + llm_input_ids[special_image_mask] = 0 + else: + llm_input_ids = input_ids + + if inputs_embeds is None: + inputs_embeds = self.get_input_embeddings()(llm_input_ids) + + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + ) + + + # Merge text and images + if pixel_values is not None: + #image_features = self.get_image_features(pixel_values) + image_features = self.model.get_image_features(pixel_values) + + if input_ids is None: + special_image_mask = inputs_embeds == self.get_input_embeddings()( + torch.tensor(self.config.image_token_index, dtype=torch.long, device=inputs_embeds.device) + ) + else: + special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1) + special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device) + + if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel(): + image_tokens_in_text = (special_image_mask).sum(dim=1).sum(dim=0)[0] + raise ValueError( + f"Number of images does not match number of special image tokens in the input text. " + f"Got {image_tokens_in_text} image tokens in the text but {image_features.shape[0] * image_features.shape[1]} " + "tokens from image embeddings." + ) + image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype) + inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features) + + # mask out pad-token-ids in labels for BC + if labels is not None and self.model.pad_token_id in labels: + logger.warning_once( + "`labels` contains `pad_token_id` which will be masked with `config.ignore_index`. 
" + "You have to mask out `pad_token_id` when preparing `labels`, this behavior will be removed in v.4.46.", + ) + labels = torch.where(input_ids == self.pad_token_id, self.config.ignore_index, labels) + + causal_mask = self.model._update_causal_mask( + attention_mask, token_type_ids, past_key_values, cache_position, inputs_embeds, is_training + ) + if labels is not None and attention_mask is not None: + attention_mask = attention_mask.to(device = labels.device) + labels[attention_mask == 0] = -100 + pass + outputs = self.model( + labels=labels, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + cache_position=cache_position, + logits_to_keep=logits_to_keep, + **lm_kwargs, + ) + hidden_states = outputs[0] + # Ensure dtype compatibility with lm_head + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + + + # Check if we're on a float16 machine with forced float32 + if os.environ.get("UNSLOTH_FORCE_FLOAT32", "0") == "1": + # Compute in float32 to avoid overflow, then convert back + hidden_states_slice = hidden_states[:, slice_indices, :].to(torch.float32) + lm_head_weight = self.lm_head.weight.to(torch.float32) + logits = torch.nn.functional.linear(hidden_states_slice, lm_head_weight, self.lm_head.bias) + if labels is None: + # If no loss computation, convert back to original dtype + logits = logits.to(self.lm_head.weight.dtype) + else: + # Normal path - ensure dtype compatibility + hidden_states = hidden_states.to(self.lm_head.weight.dtype) + logits = self.lm_head(hidden_states[:, slice_indices, :]) + + loss = None + if labels is not None: + # Upcast to float if we need to compute the loss to avoid potential precision issues + logits = logits.float() + shift_logits = logits[..., :-1, :] + shift_labels = labels[..., 1:] + if attention_mask is not None: + # we use the input attention mask to shift the logits and labels, because it is 2D. 
+ # we also crop attn mask in case it is longer, which happens in PrefixTuning with peft + shift_attention_mask = attention_mask[:, -shift_logits.shape[1] :].to(logits.device) + shift_logits = shift_logits[shift_attention_mask.to(logits.device) != 0].contiguous() + shift_labels = shift_labels[shift_attention_mask.to(shift_labels.device) != 0].contiguous() + else: + shift_logits = shift_logits.contiguous() + shift_labels = shift_labels.contiguous() + # Flatten the tokens + loss_fct = nn.CrossEntropyLoss() + + flat_logits = shift_logits.view(-1, self.config.text_config.vocab_size) + flat_labels = shift_labels.view(-1).to(shift_logits.device) + loss = loss_fct(flat_logits, flat_labels) + + return Gemma3CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + image_hidden_states=image_features if pixel_values is not None else None, + ) + pass + + old_keys = inspect.signature(transformers.models.gemma3.modeling_gemma3.Gemma3ForConditionalGeneration.forward).parameters + new_keys = inspect.signature(forward_multimodal).parameters + if old_keys != new_keys: + if UNSLOTH_ENABLE_LOGGING: + print("Unsloth: Failed patching Gemma3ForConditionalGeneration Multimodal Forward") + else: + transformers.models.gemma3.modeling_gemma3.Gemma3ForConditionalGeneration.forward_multimodal = forward_multimodal + return +pass +TEMPORARY_PATCHES.append(patch_Gemma3ForConditionalGeneration_forward_multimodal) + + +def patch_Gemma3ForConditionalGeneration_forward_llm(): + try: + import transformers.models.gemma3.modeling_gemma3 + except: + return + from transformers.models.gemma3.modeling_gemma3 import ( + Gemma3CausalLMOutputWithPast, + Cache, + ) + + def forward_llm( + self, + input_ids: torch.LongTensor = None, + pixel_values: torch.FloatTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Union[List[torch.FloatTensor], Cache]] = None, + token_type_ids: Optional[torch.LongTensor] = None, + cache_position: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + logits_to_keep: Union[int, torch.Tensor] = 0, + **lm_kwargs, + ) -> Union[Tuple, Gemma3CausalLMOutputWithPast]: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + + # Direct route through language_model + outputs = self.model.language_model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + cache_position=cache_position, + **lm_kwargs, ) - is_training = token_type_ids is not None and labels is not None + hidden_states = outputs.last_hidden_state + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep - # Replace image id woth PAD if the image token if OOV, to avoid 
index-errors - if input_ids is not None and self.config.image_token_index >= self.vocab_size: - special_image_mask = input_ids == self.config.image_token_index - llm_input_ids = input_ids.clone() - llm_input_ids[special_image_mask] = 0 + if os.environ.get("UNSLOTH_FORCE_FLOAT32", "0") == "1": + # Compute in float32 to avoid overflow, then convert back + hidden_states_slice = hidden_states[:, slice_indices, :].to(torch.float32) + lm_head_weight = self.lm_head.weight.to(torch.float32) + logits = torch.nn.functional.linear(hidden_states_slice, lm_head_weight, self.lm_head.bias) + if labels is None: + # If no loss computation, convert back to original dtype + logits = logits.to(self.lm_head.weight.dtype) else: - llm_input_ids = input_ids + # Normal path - ensure dtype compatibility + hidden_states = hidden_states.to(self.lm_head.weight.dtype) + logits = self.lm_head(hidden_states[:, slice_indices, :]) - if inputs_embeds is None: - inputs_embeds = self.get_input_embeddings()(llm_input_ids) - if cache_position is None: - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 - cache_position = torch.arange( - past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device - ) + # Apply softcapping if configured + if self.config.text_config.final_logit_softcapping is not None: + logits = logits / self.config.text_config.final_logit_softcapping + logits = torch.tanh(logits) + logits = logits * self.config.text_config.final_logit_softcapping + + loss = None + if labels is not None: + loss = self.loss_function(logits, labels, self.config.text_config.vocab_size, **lm_kwargs) + return Gemma3CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + image_hidden_states=None, # No images in text-only mode + ) + pass + + old_keys = inspect.signature(transformers.models.gemma3.modeling_gemma3.Gemma3ForConditionalGeneration.forward).parameters + new_keys = inspect.signature(forward_llm).parameters + if old_keys != new_keys: + if UNSLOTH_ENABLE_LOGGING: + print("Unsloth: Failed patching Gemma3ForConditionalGeneration language forward") + else: + transformers.models.gemma3.modeling_gemma3.Gemma3ForConditionalGeneration.forward_llm = forward_llm +pass +TEMPORARY_PATCHES.append(patch_Gemma3ForConditionalGeneration_forward_llm) - # Merge text and images - if pixel_values is not None: - image_features = self.get_image_features(pixel_values) - if input_ids is None: - special_image_mask = inputs_embeds == self.get_input_embeddings()( - torch.tensor(self.config.image_token_index, dtype=torch.long, device=inputs_embeds.device) - ) - else: - special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1) - special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device) +def patch_Gemma3ForCausalLM_forward_router(): + try: + import transformers.models.gemma3.modeling_gemma3 + except: + return + from transformers.models.gemma3.modeling_gemma3 import ( + CausalLMOutputWithPast, + HybridCache + ) - if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel(): - image_tokens_in_text = (special_image_mask).sum(dim=1).sum(dim=0)[0] - raise ValueError( - f"Number of images does not match number of special image tokens in the input text. 
" - f"Got {image_tokens_in_text} image tokens in the text but {image_features.shape[0] * image_features.shape[1]} " - "tokens from image embeddings." + def forward_router( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[HybridCache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + logits_to_keep: Union[int, torch.Tensor] = 0, + **loss_kwargs, + ) -> CausalLMOutputWithPast: + + # Routing logic + is_grpo_training = ( + os.environ.get("UNSLOTH_RETURN_HIDDEN_STATES", "0") == "1" and + self.training and + labels is None + ) + + + if is_grpo_training: + return self.grpo_forward( + input_ids, + attention_mask, + position_ids, + past_key_values, + inputs_embeds, + labels, + use_cache, + output_attentions, + output_hidden_states, + cache_position, + logits_to_keep, + **loss_kwargs + ) + else: + return self.original_forward( + input_ids, + attention_mask, + position_ids, + past_key_values, + inputs_embeds, + labels, + use_cache, + output_attentions, + output_hidden_states, + cache_position, + logits_to_keep, + **loss_kwargs ) - image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype) - inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features) - # mask out pad-token-ids in labels for BC - if labels is not None and self.pad_token_id in labels: + old_keys = inspect.signature(transformers.models.gemma3.modeling_gemma3.Gemma3ForCausalLM.forward).parameters + new_keys = inspect.signature(forward_router).parameters + if old_keys != new_keys: + if UNSLOTH_ENABLE_LOGGING: + print("Unsloth: Failed patching Gemma3ForCausalLM") + else: + transformers.models.gemma3.modeling_gemma3.Gemma3ForCausalLM.forward = forward_router + return +pass +TEMPORARY_PATCHES.append(patch_Gemma3ForCausalLM_forward_router) + + +def patch_Gemma3ForCausalLM(): + try: + import transformers.models.gemma3.modeling_gemma3 + except: + return + from transformers.models.gemma3.modeling_gemma3 import ( + CausalLMOutputWithPast, + logger, + HybridCache + ) + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[HybridCache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + logits_to_keep: Union[int, torch.Tensor] = 0, + **loss_kwargs, + ) -> CausalLMOutputWithPast: + + if self.training and self.config._attn_implementation != "eager": logger.warning_once( - "`labels` contains `pad_token_id` which will be masked with `config.ignore_index`. " - "You have to mask out `pad_token_id` when preparing `labels`, this behavior will be removed in v.4.46.", + "It is strongly recommended to train Gemma3 models with the `eager` attention implementation " + f"instead of `{self.config._attn_implementation}`. Use `eager` with `AutoModelForCausalLM.from_pretrained('', attn_implementation='eager')`." 
) - labels = torch.where(input_ids == self.pad_token_id, self.config.ignore_index, labels) - - causal_mask = self._update_causal_mask( - attention_mask, token_type_ids, past_key_values, cache_position, inputs_embeds, is_training + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) - if labels is not None and attention_mask is not None: - attention_mask = attention_mask.to(device = labels.device) - labels[attention_mask == 0] = -100 - pass - outputs = self.language_model( - labels=labels, - attention_mask=causal_mask, + + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, @@ -259,154 +632,158 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, cache_position=cache_position, - logits_to_keep=logits_to_keep, - **lm_kwargs, + **loss_kwargs, ) - labels = None - # We NEVER ENTER if labels is not None: since we already accounted for it + hidden_states = outputs.last_hidden_state + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + + if os.environ.get("UNSLOTH_FORCE_FLOAT32", "0") == "1": + # Compute in float32 to avoid overflow, then convert back + hidden_states_slice = hidden_states[:, slice_indices, :].to(torch.float32) + lm_head_weight = self.lm_head.weight.to(torch.float32) + lm_head_bias = self.lm_head.bias.to(torch.float32) if self.lm_head.bias is not None else None + logits = torch.nn.functional.linear(hidden_states_slice, lm_head_weight, lm_head_bias) + if labels is None: + # If no loss computation, convert back to original dtype + logits = logits.to(self.lm_head.weight.dtype) + else: + # ensure dtype compatibility + hidden_states = hidden_states.to(self.lm_head.weight.dtype) + logits = self.lm_head(hidden_states[:, slice_indices, :]) + + # Apply final logit softcapping if configured + if self.config.final_logit_softcapping is not None: + logits = logits / self.config.final_logit_softcapping + logits = torch.tanh(logits) + logits = logits * self.config.final_logit_softcapping - logits = outputs.logits loss = None if labels is not None: - # Upcast to float if we need to compute the loss to avoid potential precision issues - logits = logits.float() - shift_logits = logits[..., :-1, :] - shift_labels = labels[..., 1:] - if attention_mask is not None: - # we use the input attention mask to shift the logits and labels, because it is 2D. 
- # we also crop attn mask in case it is longer, which happens in PrefixTuning with peft - shift_attention_mask = attention_mask[:, -shift_logits.shape[1] :].to(logits.device) - shift_logits = shift_logits[shift_attention_mask.to(logits.device) != 0].contiguous() - shift_labels = shift_labels[shift_attention_mask.to(shift_labels.device) != 0].contiguous() - else: - shift_logits = shift_logits.contiguous() - shift_labels = shift_labels.contiguous() - # Flatten the tokens - loss_fct = nn.CrossEntropyLoss() - - flat_logits = shift_logits.view(-1, self.config.text_config.vocab_size) - flat_labels = shift_labels.view(-1).to(shift_logits.device) - loss = loss_fct(flat_logits, flat_labels) - loss = outputs.loss + loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs) - return Gemma3CausalLMOutputWithPast( + return CausalLMOutputWithPast( loss=loss, logits=logits, past_key_values=outputs.past_key_values, hidden_states=outputs.hidden_states, attentions=outputs.attentions, - image_hidden_states=image_features if pixel_values is not None else None, ) - pass - old_keys = inspect.signature(transformers.models.gemma3.modeling_gemma3.Gemma3ForConditionalGeneration.forward).parameters + old_keys = inspect.signature(transformers.models.gemma3.modeling_gemma3.Gemma3ForCausalLM.forward).parameters new_keys = inspect.signature(forward).parameters if old_keys != new_keys: - if UNSLOTH_ENABLE_LOGGING: - print("Unsloth: Failed patching Gemma3ForConditionalGeneration.forward v1") + print("Unsloth: Failed to patch Gemma3ForCausalLM compatibility.") else: - transformers.models.gemma3.modeling_gemma3.Gemma3ForConditionalGeneration.forward = forward + transformers.models.gemma3.modeling_gemma3.Gemma3ForCausalLM.original_forward = forward + return +pass +TEMPORARY_PATCHES.append(patch_Gemma3ForCausalLM) + + +def patch_Gemma3ForCausalLMGRPO(): + try: + import transformers.models.gemma3.modeling_gemma3 + except: return + from transformers.models.gemma3.modeling_gemma3 import ( + CausalLMOutputWithPast, + logger, + HybridCache + ) def forward( self, - input_ids: torch.LongTensor = None, - pixel_values: torch.FloatTensor = None, + input_ids: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[Union[List[torch.FloatTensor], Cache]] = None, - token_type_ids: Optional[torch.LongTensor] = None, - cache_position: Optional[torch.LongTensor] = None, + past_key_values: Optional[HybridCache] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, logits_to_keep: Union[int, torch.Tensor] = 0, - **lm_kwargs, - ) -> Union[Tuple, Gemma3CausalLMOutputWithPast]: + **loss_kwargs, + ) -> CausalLMOutputWithPast: + + if self.training and self.config._attn_implementation != "eager": + logger.warning_once( + "It is strongly recommended to train Gemma3 models with the `eager` attention implementation " + f"instead of `{self.config._attn_implementation}`. Use `eager` with `AutoModelForCausalLM.from_pretrained('', attn_implementation='eager')`." 
+ ) output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - if labels is not None and attention_mask is not None: - attention_mask = attention_mask.to(device = labels.device) - labels[attention_mask == 0] = -100 - pass + outputs = self.model( input_ids=input_ids, - pixel_values=pixel_values, - token_type_ids=token_type_ids, attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, use_cache=use_cache, - labels=labels, output_attentions=output_attentions, output_hidden_states=output_hidden_states, - return_dict=return_dict, cache_position=cache_position, - **lm_kwargs, + **loss_kwargs, ) - labels = None - # We NEVER ENTER if labels is not None: since we already accounted for it - hidden_states = outputs[0] - # Only compute necessary logits, and do not upcast them to float if we are not computing the loss - slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep - logits = self.lm_head(hidden_states[:, slice_indices, :]) + hidden_states = outputs.last_hidden_state - loss = None - if labels is not None: - # Upcast to float if we need to compute the loss to avoid potential precision issues - logits = logits.float() - shift_logits = logits[..., :-1, :] - shift_labels = labels[..., 1:] - if attention_mask is not None: - # we use the input attention mask to shift the logits and labels, because it is 2D. - # we also crop attn mask in case it is longer, which happens in PrefixTuning with peft - shift_attention_mask = attention_mask[:, -shift_logits.shape[1] :].to(logits.device) - shift_logits = shift_logits[shift_attention_mask.to(logits.device) != 0].contiguous() - shift_labels = shift_labels[shift_attention_mask.to(shift_labels.device) != 0].contiguous() - else: - shift_logits = shift_logits.contiguous() - shift_labels = shift_labels.contiguous() - # Flatten the tokens - loss_fct = nn.CrossEntropyLoss() - - flat_logits = shift_logits.view(-1, self.config.text_config.vocab_size) - flat_labels = shift_labels.view(-1).to(shift_logits.device) - loss = loss_fct(flat_logits, flat_labels) - loss = outputs.loss - - if not return_dict: - output = (logits,) + outputs[1:] - return (loss,) + output if loss is not None else output + # Handle logits_to_keep parameter first + if isinstance(logits_to_keep, int) and logits_to_keep > 0: + hidden_states = hidden_states[:, -logits_to_keep:, :] - return Gemma3CausalLMOutputWithPast( - loss=loss, - logits=logits, + # Convert hidden states to logits using lm_head + # Handle UNSLOTH_FORCE_FLOAT32 for the projection + if os.environ.get("UNSLOTH_FORCE_FLOAT32", "0") == "1": + hidden_states_fp32 = hidden_states.to(torch.float32) + lm_head_weight = self.lm_head.weight.to(torch.float32) + lm_head_bias = self.lm_head.bias.to(torch.float32) if self.lm_head.bias is not None else None + logits = torch.nn.functional.linear(hidden_states_fp32, lm_head_weight, lm_head_bias) + # Keep in float32 for GRPO + else: + hidden_states = hidden_states.to(self.lm_head.weight.dtype) + logits = self.lm_head(hidden_states) + # Convert to float32 for GRPO compatibility + logits = logits.to(torch.float32) + + # Apply final logit softcapping if configured + if self.config.final_logit_softcapping is not None: + logits = logits / 
self.config.final_logit_softcapping + logits = torch.tanh(logits) + logits = logits * self.config.final_logit_softcapping + + # GRPO expects logits without the last position (next token prediction) + # Only remove last position if we have more than 1 position + if logits.dim() == 3 and logits.shape[1] > 1: # (batch, seq_len, vocab_size) + # Remove last position for GRPO compatibility + logits = logits[:, :-1, :] + + # Ensure contiguous memory layout for torch.compile compatibility + logits = logits.contiguous() + + return CausalLMOutputWithPast( + loss=None, + logits=logits, # Return actual logits for GRPO past_key_values=outputs.past_key_values, hidden_states=outputs.hidden_states, attentions=outputs.attentions, - image_hidden_states=outputs.image_hidden_states, ) - pass - old_keys = inspect.signature(transformers.models.gemma3.modeling_gemma3.Gemma3ForConditionalGeneration.forward).parameters + # Apply the patch + old_keys = inspect.signature(transformers.models.gemma3.modeling_gemma3.Gemma3ForCausalLM.forward).parameters new_keys = inspect.signature(forward).parameters if old_keys != new_keys: - if UNSLOTH_ENABLE_LOGGING: - print("Unsloth: Failed patching Gemma3ForConditionalGeneration.forward v2") + print("Unsloth: Failed to patch Gemma3ForCausalLM for GRPO compatibility.") else: - transformers.models.gemma3.modeling_gemma3.Gemma3ForConditionalGeneration.forward = forward + transformers.models.gemma3.modeling_gemma3.Gemma3ForCausalLM.grpo_forward = forward + return pass -TEMPORARY_PATCHES.append(patch_Gemma3ForConditionalGeneration) - +TEMPORARY_PATCHES.append(patch_Gemma3ForCausalLMGRPO) def patch_Gemma3ForConditionalGeneration_causal_mask(): if os.environ.get("UNSLOTH_FORCE_FLOAT32", "0") == "0": return @@ -536,16 +913,31 @@ def patch_Gemma3RMSNorm(): if os.environ.get("UNSLOTH_FORCE_FLOAT32", "0") == "0": return try: import transformers.models.gemma3.modeling_gemma3 except: return - def forward(self, x): - x = x.to(torch.float32) - output = x * torch.rsqrt(x.square().mean(-1, keepdim = True) + self.eps) - return output * (1.0 + self.weight.float()) + + original_rmsnorm_forward = transformers.models.gemma3.modeling_gemma3.Gemma3RMSNorm.forward + + def forward(self, x): # x can be fp32 (from embeddings) or fp16 (from MLP/Attn) + # Internals in fp32 + x_fp32 = x.to(torch.float32) + variance = x_fp32.pow(2).mean(-1, keepdim=True) + hidden_states_fp32 = x_fp32 * torch.rsqrt(variance + self.eps) + + # self.weight is bf16 (from vision.py loading if UNSLOTH_FORCE_FLOAT32="1") + # So, cast self.weight to fp32 for the (1.0 + weight) operation + output_fp32 = hidden_states_fp32 * (1.0 + self.weight.to(torch.float32)) + + # Clamp to fp16 range before casting back to fp16 + fp16_max = torch.finfo(torch.float16).max + fp16_min = torch.finfo(torch.float16).min + clamped_output_fp32 = torch.clamp(output_fp32, min=fp16_min, max=fp16_max) + + return clamped_output_fp32.to(torch.float16) # Output fp16 pass - old_keys = inspect.signature(transformers.models.gemma3.modeling_gemma3.Gemma3RMSNorm.forward).parameters + old_keys = inspect.signature(original_rmsnorm_forward).parameters new_keys = inspect.signature(forward).parameters if old_keys != new_keys: if UNSLOTH_ENABLE_LOGGING: - print("Unsloth: Failed to patch Gemma3RMSNorm.") + print("Unsloth: Failed to patch Gemma3RMSNorm (adjusted). 
Signature mismatch.") else: forward = torch.compile(forward, fullgraph = True, dynamic = True, options = torch_compile_options) transformers.models.gemma3.modeling_gemma3.Gemma3RMSNorm.forward = forward @@ -558,40 +950,53 @@ def patch_Gemma3MLP(): if os.environ.get("UNSLOTH_FORCE_FLOAT32", "0") == "0": return try: import transformers.models.gemma3.modeling_gemma3 except: return - def forward(self, x): - x = x.to(torch.float16) - down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) - return down_proj.to(torch.float32) + + original_mlp_forward = transformers.models.gemma3.modeling_gemma3.Gemma3MLP.forward + + def forward(self, x): # x is fp16 from RMSNorm + gate_proj_out = self.gate_proj(x) + up_proj_out = self.up_proj(x) + + # Upcast to fp32 + gate_proj_fp32 = gate_proj_out.to(torch.float32) + up_proj_fp32 = up_proj_out.to(torch.float32) + activated_fp32 = self.act_fn(gate_proj_fp32) # Activation in fp32 + intermediate_fp32 = activated_fp32 * up_proj_fp32 # Product in fp32 + + # Downcast and down_proj + intermediate_fp16 = intermediate_fp32.to(torch.float16) + down_proj_out = self.down_proj(intermediate_fp16) + return down_proj_out pass - old_keys = inspect.signature(transformers.models.gemma3.modeling_gemma3.Gemma3MLP.forward).parameters + + old_keys = inspect.signature(original_mlp_forward).parameters new_keys = inspect.signature(forward).parameters if old_keys != new_keys: - print("Unsloth: Failed to patch Gemma3MLP.") + if UNSLOTH_ENABLE_LOGGING: print("Unsloth: Failed to patch Gemma3MLP") else: forward = torch.compile(forward, fullgraph = False, dynamic = True, options = torch_compile_options) transformers.models.gemma3.modeling_gemma3.Gemma3MLP.forward = forward + return pass TEMPORARY_PATCHES.append(patch_Gemma3MLP) def patch_Gemma3Attention(): - if os.environ.get("UNSLOTH_FORCE_FLOAT32", "0") == "1": - downcast_dtype = torch.float16 - else: - downcast_dtype = torch.bfloat16 - try: import transformers.models.gemma3.modeling_gemma3 - except: return + if os.environ.get("UNSLOTH_FORCE_FLOAT32", "0") == "0": return + + try: + import transformers.models.gemma3.modeling_gemma3 + except: + return + import typing from transformers.models.gemma3.modeling_gemma3 import ( - Cache, - Unpack, - FlashAttentionKwargs, - apply_rotary_pos_emb, - ALL_ATTENTION_FUNCTIONS, - logger, - eager_attention_forward, + Cache, Unpack, FlashAttentionKwargs, apply_rotary_pos_emb, ) + scaled_dot_product_attention = torch.nn.functional.scaled_dot_product_attention + original_hf_attention_forward = transformers.models.gemma3.modeling_gemma3.Gemma3Attention.forward + def forward( self, hidden_states: torch.Tensor, @@ -600,80 +1005,110 @@ def forward( past_key_value: Optional[Cache] = None, cache_position: Optional[torch.LongTensor] = None, **kwargs: Unpack[FlashAttentionKwargs], - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - input_shape = hidden_states.shape[:-1] - hidden_shape = (*input_shape, -1, self.head_dim) + ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]: - hidden_states = hidden_states.to(downcast_dtype) - query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2) - key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) - value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) + bsz, q_len, _ = hidden_states.shape + input_shape = hidden_states.shape[:-1] # For reshaping o_proj output later - query_states = self.q_norm(query_states) - key_states = 
self.k_norm(key_states) + # Determine head shapes + # Assuming these attributes are standard for Gemma3Attention + # If not, they might come from self.config + num_heads = getattr(self, "num_heads", self.config.num_attention_heads) + num_key_value_heads = getattr(self, "num_key_value_heads", self.config.num_key_value_heads) + head_dim = self.head_dim + + # For projections view: (bsz, q_len, num_specific_heads, head_dim) + query_hidden_shape = (bsz, q_len, num_heads, head_dim) + kv_hidden_shape = (bsz, q_len, num_key_value_heads, head_dim) + + # 1. Projections (q, k, v) in fp16 + # hidden_states is already fp16. Weights of q_proj, k_proj, v_proj are fp16. + query_states_fp16 = self.q_proj(hidden_states) # output fp16 + key_states_fp16 = self.k_proj(hidden_states) # output fp16 + value_states_fp16 = self.v_proj(hidden_states) # output fp16 + + # 2. Upcast Q, K, V for norm and RoPE, and then transpose for attention + # (bsz, num_specific_heads, q_len, head_dim) + query_states_fp32 = query_states_fp16.view(query_hidden_shape).to(torch.float32).transpose(1, 2) + key_states_fp32 = key_states_fp16.view(kv_hidden_shape).to(torch.float32).transpose(1, 2) + value_states_fp32 = value_states_fp16.view(kv_hidden_shape).to(torch.float32).transpose(1, 2) # V for attention also fp32 + + # 3. Normalization (q_norm, k_norm are RMSNorms) + query_norm_out_fp16 = self.q_norm(query_states_fp32) + key_norm_out_fp16 = self.k_norm(key_states_fp32) + + query_states_fp32 = query_norm_out_fp16.to(torch.float32) + key_states_fp32 = key_norm_out_fp16.to(torch.float32) + + # 4. Rotary Positional Embeddings in fp32 + if not (isinstance(position_embeddings, tuple) and len(position_embeddings) == 2): + raise ValueError("Position embeddings not provided as (cos, sin) tuple to Gemma3Attention") cos, sin = position_embeddings - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + cos_fp32 = cos.to(torch.float32) + sin_fp32 = sin.to(torch.float32) + query_states_fp32, key_states_fp32 = apply_rotary_pos_emb(query_states_fp32, key_states_fp32, cos_fp32, sin_fp32) + # 5. KV Cache update (using fp32 K, V) if past_key_value is not None: - # sin and cos are specific to RoPE models; cache_position needed for the static cache cache_kwargs = { - "sin": sin, - "cos": cos, - "cache_position": cache_position, - "sliding_window": self.sliding_window, + "sin": sin_fp32, "cos": cos_fp32, "cache_position": cache_position } - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - - # Here we need to slice as we use a static cache by default, but FA2 does not support it - if attention_mask is not None and self.config._attn_implementation == "flash_attention_2": - seq_len = attention_mask.shape[-1] - key_states, value_states = key_states[:, :, :seq_len, :], value_states[:, :, :seq_len, :] - - # attention_interface: Callable = eager_attention_forward - # if self.config._attn_implementation != "eager": - # if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False): - # logger.warning_once( - # "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. " - # "Falling back to eager attention. This warning can be removed using the argument " - # '`attn_implementation="eager"` when loading the model.' 
- # ) - # else: - # attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] - - # attn_output, attn_weights = attention_interface( - # self, - # query_states.to(downcast_dtype), - # key_states.to(downcast_dtype), - # value_states.to(downcast_dtype), - # attention_mask.to(downcast_dtype), - # dropout=self.attention_dropout if self.training else 0.0, - # scaling=self.scaling, - # sliding_window=self.sliding_window, - # **kwargs, - # ) - attn_output = scaled_dot_product_attention( - query_states.to(downcast_dtype), - key_states.to(downcast_dtype), - value_states.to(downcast_dtype), - attn_mask=attention_mask.to(downcast_dtype), + # Add sliding_window if the attribute exists (common in newer models) + if hasattr(self, "sliding_window") and self.sliding_window is not None: + cache_kwargs["sliding_window"] = self.sliding_window + key_states_fp32, value_states_fp32 = past_key_value.update( + key_states_fp32, value_states_fp32, self.layer_idx, cache_kwargs + ) + + # 6. Core Attention mechanism (SDPA) in fp32 + attn_mask_for_sdpa = attention_mask + if attn_mask_for_sdpa is not None: + attn_mask_for_sdpa = attn_mask_for_sdpa.to(torch.float32) + + output_attentions = kwargs.get("output_attentions", False) + + + attn_output_fp32 = scaled_dot_product_attention( + query_states_fp32, + key_states_fp32, + value_states_fp32, + attn_mask=attn_mask_for_sdpa, dropout_p=self.attention_dropout if self.training else 0.0, - scale=self.scaling, + # is_causal=False, # Mask handles causality. If mask is None and q_len > 1, this might be true. + # Gemma3's _update_causal_mask provides the explicit mask. + scale=getattr(self, "scaling", None), # Use self.scaling if defined, else SDPA default enable_gqa=getattr(self, "num_key_value_groups", 1) != 1, - ).transpose(1, 2) + ) + attn_weights = None # Defaulting to None + + # 7. Reshape and Downcast for Output Projection + # attn_output_fp32 from SDPA is (bsz, num_heads, q_len, head_dim) + attn_output_fp32 = attn_output_fp32.transpose(1, 2).contiguous() + + # Reshape to (bsz, q_len, num_query_heads * head_dim) which is (bsz, q_len, model_hidden_size) + # Using -1 for the last dimension is robust and aligns with your original example. + attn_output_fp32 = attn_output_fp32.reshape(bsz, q_len, -1) # REVISED FIX - attn_output = attn_output.reshape(*input_shape, -1)#.contiguous() - attn_output = self.o_proj(attn_output) - return attn_output, None + attn_output_fp16 = attn_output_fp32.to(torch.float16) + + # 8. Output Projection (o_proj) in fp16 + attn_output_projected = self.o_proj(attn_output_fp16) # fp16 output + + return attn_output_projected, attn_weights # 3-tuple return pass - old_keys = inspect.signature(transformers.models.gemma3.modeling_gemma3.Gemma3Attention.forward).parameters - new_keys = inspect.signature(forward).parameters - if old_keys != new_keys: + + old_keys_sig = inspect.signature(original_hf_attention_forward) + new_keys_sig = inspect.signature(forward) + + if old_keys_sig.parameters != new_keys_sig.parameters or old_keys_sig.return_annotation != new_keys_sig.return_annotation: if UNSLOTH_ENABLE_LOGGING: - print("Unsloth: Failed to patch Gemma3Attention.") + print("Unsloth: Failed to patch Gemma3Attention (adjusted for signature matching). 
Signature mismatch with original HF method.") else: forward = torch.compiler.disable(forward, recursive = False) transformers.models.gemma3.modeling_gemma3.Gemma3Attention.forward = forward + if UNSLOTH_ENABLE_LOGGING: + print("Unsloth: Patched Gemma3Attention.forward (adjusted for signature matching, output fp16).") return pass TEMPORARY_PATCHES.append(patch_Gemma3Attention)