rolandtannous commented on Jun 21, 2025

Problem

Gemma 3 models suffer from several critical issues on both bfloat16 and float16-only hardware:

  • Transformers 4.52.3 architecture changes: A signature change causes the existing patches to fail to apply. The new scheme also introduces breaking changes that lead to failed loss computations (Gemma3LLMCausalOutput does not return logits or loss in the new Transformers).
  • Routing logic errors: Labels are always set to None, so loss calculation is never activated during training runs.
  • Gradient Explosions: Activations exceed float16's maximum value (65,504) on float16-only hardware, causing infinite or NaN gradients (illustrated after this list).
  • Tensor Type Mismatches: Inconsistent dtype handling between different model components
  • AMP Incompatibility: PyTorch's automatic mixed precision forces incompatible dtype combinations
  • Vision Layer Interference: Multimodal architecture causes training instability even for text-only tasks
  • GRPO Training Failures: Multiple tensor incompatibilities in reasoning-based training workflows
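The float16 overflow mentioned above is easy to reproduce in isolation: any activation whose magnitude exceeds 65,504 saturates to inf, and anything downstream of it (softmax, normalization) turns into inf or NaN. A minimal standalone illustration, not code from this PR:

```python
import torch

# float16 saturates just above 65,504; float32 does not.
x = torch.tensor([60000.0], dtype=torch.float16, requires_grad=True)
y = x * 1.2                 # 72,000 > 65,504  ->  inf in float16
loss = y.sum()
loss.backward()
print(y)                    # tensor([inf], dtype=torch.float16, ...)
print(x.grad)               # still finite (1.2), but any op consuming y propagates inf/NaN

# the same computation in float32 stays finite
x32 = torch.tensor([60000.0], dtype=torch.float32)
print(x32 * 1.2)            # tensor([72000.])
```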

Solves Issues and PRs:

Solution

Precision handling on float16-only hardware:

Restructured, rewrote, and realigned the gemma3 model component patches and implementation, following the initial approach described in the gemma3-fixes blog post (a minimal sketch of the precision split follows the list below).

  • Float32 for Stability: All intermediate computations, activation functions, normalizations, and gradient-sensitive operations execute in float32. This includes attention score computation, MLP intermediate results, embedding scaling, and RMSNorm calculations where overflow is most likely.
  • Float16 for Efficiency: Matrix multiplications in linear projections (q_proj, k_proj, v_proj, o_proj) remain in float16 to leverage tensor cores. Model weights stay in float16 for memory efficiency, and inter-layer activations are passed in float16 to minimize bandwidth.
  • Gradient Checkpointing Enhancements: When UNSLOTH_FORCE_FLOAT32=1, the code calls newly introduced float32-compatible checkpoint functions that avoid AMP decorators entirely, preventing the forced dtype conversions that cause instability and dtype mismatch errors. Buffer management was also introduced to coordinate CPU/GPU memory with dtype-aware allocation, ensuring intermediate activations are stored in the correct precision for recomputation during backward passes.
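A minimal sketch of what the float32/float16 split looks like in practice. The function names (`rms_norm_fp32`, `attention_scores_fp32`) are illustrative and not the exact identifiers used in this PR:

```python
import torch
import torch.nn.functional as F

def rms_norm_fp32(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    """RMSNorm computed in float32 to avoid float16 overflow; result cast back to the input dtype."""
    x32 = x.float()
    variance = x32.pow(2).mean(-1, keepdim=True)
    x32 = x32 * torch.rsqrt(variance + eps)
    return (weight.float() * x32).to(x.dtype)

def attention_scores_fp32(q: torch.Tensor, k: torch.Tensor, scale: float) -> torch.Tensor:
    """q/k come from float16 projections; the score matmul and softmax run in float32."""
    scores = torch.matmul(q.float(), k.float().transpose(-1, -2)) * scale
    probs = F.softmax(scores, dim=-1)
    return probs.to(q.dtype)  # hand float16 activations back to the next layer to save bandwidth

if __name__ == "__main__":
    x = torch.randn(2, 8, 64, dtype=torch.float16)
    w = torch.ones(64, dtype=torch.float16)
    print(rms_norm_fp32(x, w).dtype)  # float16 activations out, float32 math inside
```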

GRPO Training Issues:

Implemented specialized forward routing and tensor handling to resolve reasoning-training incompatibilities in Gemma3 models (see the sketch after the list below).

  • Dedicated GRPO Forward Path: Created grpo_forward method that handles logits-to-keep parameter processing, removes final sequence positions for next-token prediction compatibility, and ensures contiguous tensor layouts required for scatter operations during gradient computation.
  • Float32 Logits Enforcement: Modified output projection to guarantee float32 logits regardless of underlying model precision, preventing numerical instability in policy gradient computations while maintaining efficient float16 internal processing.
  • Environment-Based Detection: Added UNSLOTH_RETURN_HIDDEN_STATES flag detection combined with training mode checks to automatically route GRPO workflows through the specialized path, eliminating manual configuration requirements.
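A rough sketch of the routing described above, assuming the standard Transformers attribute layout (`self.model`, `self.lm_head`); the exact control flow in the patch may differ:

```python
import os
import torch

def grpo_forward(self, input_ids, attention_mask=None, logits_to_keep=0, **kwargs):
    """Illustrative GRPO-specific forward path."""
    hidden = self.model(input_ids=input_ids, attention_mask=attention_mask, **kwargs)[0]
    hidden = hidden[:, :-1, :]                    # drop the final position for next-token alignment
    if logits_to_keep:
        hidden = hidden[:, -logits_to_keep:, :]   # keep only the trailing positions GRPO scores
    hidden = hidden.contiguous()                  # scatter ops during gradient computation need contiguous layouts
    logits = self.lm_head(hidden).float()         # float32 logits even if the model computes in float16
    return logits

def patched_forward(self, *args, **kwargs):
    # route through the GRPO path when the env flag is set and the model is in training mode
    if os.environ.get("UNSLOTH_RETURN_HIDDEN_STATES", "0") == "1" and self.training:
        return grpo_forward(self, *args, **kwargs)
    return self._original_forward(*args, **kwargs)  # hypothetical handle to the unpatched forward
```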

Language Model Route for Gemma3ForConditionalGeneration

  • Clean Activation Pipeline: The language-only route maintains a pristine forward pass without vision-related masking, token type complexity, or image feature integration, ensuring optimal numerical stability for pure text training scenarios.
  • Architectural Flexibility: Preserved full multimodal capabilities through forward_multimodal while providing the performance and stability benefits of dedicated language model processing when images aren't present, eliminating the need for separate model variants.
  • Training Stability Enhancement: Implemented detection of text-only scenarios (absence of pixel_values and token_type_ids) to automatically engage the streamlined language-only forward method, eliminating the NaN gradient norms and training instability caused by dormant vision components interfering with text-only gradient flow in the current patch implementation (see the routing sketch below).
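A minimal sketch of the text-only detection. `forward_multimodal` is named above; `forward_language` is an assumed name for the language-only route:

```python
def forward(self, input_ids=None, pixel_values=None, token_type_ids=None, **kwargs):
    # text-only batch: no pixel values and no token type ids -> skip the vision path entirely
    if pixel_values is None and token_type_ids is None:
        return self.forward_language(input_ids=input_ids, **kwargs)
    # image-text batch: keep the full multimodal pipeline (vision tower, masking, image features)
    return self.forward_multimodal(
        input_ids=input_ids,
        pixel_values=pixel_values,
        token_type_ids=token_type_ids,
        **kwargs,
    )
```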

Tests:

Tested both on bfloat16 (H100) and float16-only hardware (vanilla T4 and T4 with High RAM)

Used perplexity and OCR vision WER/CER metrics to compare model performance throughout its lifecycle (sketches of both metrics appear after their respective tables below). The findings were as follows:

Perplexity tests for Gemma3ForCausalLM

Gemma-3-1b Perplexity tests

| Precision | Base Model | PEFT Model | Merged load 4bit | Merged load 8bit | Merged load 16bit |
|-----------|------------|------------|------------------|------------------|-------------------|
| Bfloat16  | 57.3395    | 13.343345  | 13.503703        | 12.597788        | 12.597788         |
| Float16   | 81.5141    | 18.968971  | 19.462515        | 17.541702        | 17.541702         |
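Perplexity here is the exponential of the mean token-level cross-entropy. A minimal sketch of how such a score can be computed over an eval split (the dataloader and token counting are placeholders, not the exact evaluation harness used for these numbers):

```python
import math
import torch

@torch.no_grad()
def compute_perplexity(model, dataloader, device="cuda"):
    """exp(average cross-entropy per token) over the eval set."""
    total_loss, total_tokens = 0.0, 0
    model.eval()
    for batch in dataloader:
        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)
        out = model(input_ids=input_ids, attention_mask=attention_mask, labels=input_ids)
        n_tokens = attention_mask.sum().item()   # approximate token count from the attention mask
        total_loss += out.loss.item() * n_tokens
        total_tokens += n_tokens
    return math.exp(total_loss / total_tokens)
```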

Vision OCR tests on Gemma3ForConditionalGeneration

Gemma-3-4b OCR vision tests

| Precision | Base Model | PEFT Model | Merged load 4bit | Merged load 16bit |
|-----------|------------|------------|------------------|-------------------|
| Bfloat16  | WER: 0.8584 / CER: 0.6946 | WER: 0.0440 / CER: 0.0088 | WER: 0.0540 / CER: 0.0103 | WER: 0.0492 / CER: 0.0108 |
| Float16   | WER: 0.8846 / CER: 0.7211 | WER: 0.0475 / CER: 0.0085 | WER: 0.0578 / CER: 0.0113 | WER: 0.0437 / CER: 0.0082 |

  • Figures depend on training args; smaller is better for the WER and CER metrics.
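WER and CER compare generated OCR transcriptions against reference text at the word and character level. A short sketch using the jiwer package, with placeholder reference/prediction lists rather than the actual eval data:

```python
import jiwer

references = ["the quick brown fox", "hello world"]    # ground-truth OCR text
predictions = ["the quick brown fox", "helo world"]    # model transcriptions

wer = jiwer.wer(references, predictions)   # word error rate
cer = jiwer.cer(references, predictions)   # character error rate
print(f"WER: {wer:.4f}  CER: {cer:.4f}")
```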

Ran the following E2E tests successfully on both architectures using the official unsloth notebooks:

Associated PRs

Also needs unsloth PR #2780 (Gemma3ForCausalLM object has no attribute self.llm) to solve the GRPO attribute error.
