Gemma3 architecture and patch fixes #174
Merged
Problem
Gemma 3 models suffer from several critical issues on both bfloat16 and float16-only hardware.
Resolves the following issues and PRs:
Discord-reported error when running GRPO with Gemma 3: `The size of tensor a (s5) must match the size of tensor b (s2) at non-singleton dimension 1`
Solution
Precision handling on float16-only hardware:
Restructured, rewrote, and realigned the Gemma3 model component patches and implementation, following the initial approach described in the gemma3-fixes blog post.
When `UNSLOTH_FORCE_FLOAT32=1` is set, the code calls newly introduced float32-compatible checkpoint functions that avoid AMP decorators entirely, preventing the forced dtype conversions that cause instability and dtype-mismatch errors. Also introduced buffer management that coordinates CPU/GPU memory with dtype-aware allocation, ensuring intermediate activations are stored in the correct precision for recomputation during the backward pass (see the sketch below).
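For illustration, here is a minimal sketch of these two mechanisms. The function names and structure are assumptions for clarity, not the actual Unsloth implementation:

```python
import os
import torch
from torch.utils.checkpoint import checkpoint

def float32_safe_checkpoint(fn, *args, **kwargs):
    """Checkpoint `fn` without AMP so activations keep their original dtype.
    (Hypothetical helper; illustrates the idea, not the PR's exact code.)"""
    if os.environ.get("UNSLOTH_FORCE_FLOAT32", "0") == "1":
        # Disable autocast so neither the forward pass nor the backward-time
        # recomputation silently downcasts float32 activations.
        with torch.autocast(device_type="cuda", enabled=False):
            return checkpoint(fn, *args, use_reentrant=False, **kwargs)
    return checkpoint(fn, *args, use_reentrant=False, **kwargs)

def allocate_activation_buffers(shape, compute_dtype):
    """Dtype-aware allocation: a pinned CPU buffer plus a GPU tensor for
    offloaded activations, kept in float32 when the flag is set."""
    dtype = (torch.float32
             if os.environ.get("UNSLOTH_FORCE_FLOAT32", "0") == "1"
             else compute_dtype)
    cpu_buffer = torch.empty(shape, dtype=dtype, pin_memory=True)
    gpu_buffer = torch.empty(shape, dtype=dtype, device="cuda")
    return cpu_buffer, gpu_buffer
```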
GRPO Training Issues:
Implemented specialized forward routing and tensor handling to resolve reasoning-training incompatibilities in Gemma3 models.
Language Model Route for `Gemma3ForConditionalGeneration` (sketched below)
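A hedged sketch of what such a route can look like: for text-only batches (as in GRPO), dispatch past the multimodal wrapper to the underlying text decoder. Attribute layout varies across transformers versions, so treat this as illustrative rather than the PR's exact code:

```python
def forward_with_lm_route(model, input_ids, attention_mask=None, **kwargs):
    """Route text-only batches straight to the language-model component of a
    Gemma3ForConditionalGeneration. Illustrative only."""
    pixel_values = kwargs.pop("pixel_values", None)
    if model.__class__.__name__ == "Gemma3ForConditionalGeneration" \
            and pixel_values is None:
        # Text-only batch: bypass the multimodal wrapper and run the text
        # decoder directly, avoiding the tensor-size mismatch at dimension 1.
        return model.language_model(
            input_ids=input_ids, attention_mask=attention_mask, **kwargs
        )
    return model(input_ids=input_ids, attention_mask=attention_mask,
                 pixel_values=pixel_values, **kwargs)
```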
Tests:
Tested on both bfloat16 (H100) and float16-only hardware (vanilla T4 and T4 with high RAM).
Used perplexity and vision-OCR WER/CER metrics to compare model performance throughout its lifecycle (a minimal sketch of the perplexity measurement follows the table below). The findings were as follows:
Perplexity tests for Gemma3ForCausalLM
Gemma-3-1b Perplexity tests
| Precision | Base Model | PEFT Model | Merged (load 4-bit) | Merged (load 8-bit) | Merged (load 16-bit) |
| --- | --- | --- | --- | --- | --- |
| bfloat16 | 57.3395 | 13.343345 | 13.503703 | 12.597788 | 12.597788 |
| float16 | 81.5141 | 18.968971 | 19.462515 | 17.541702 | 17.541702 |
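For reference, a minimal sketch of how perplexity numbers like those above can be measured (the checkpoint id and corpus path are illustrative assumptions, not the exact test harness):

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "google/gemma-3-1b-it"  # illustrative checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id, torch_dtype=torch.bfloat16, device_map="cuda"
)

text = open("eval_corpus.txt").read()  # any held-out evaluation text
enc = tokenizer(text, return_tensors="pt").to("cuda")
with torch.no_grad():
    # With labels == input_ids, the model returns the mean next-token NLL.
    loss = model(**enc, labels=enc["input_ids"]).loss
print(f"perplexity: {torch.exp(loss).item():.4f}")  # ppl = exp(mean NLL)
```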
Vision OCR tests on Gemma3ForConditionalGeneration
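The WER/CER scoring for these tests can be computed with any edit-distance implementation; a sketch using the jiwer library (an assumption, with made-up example strings):

```python
import jiwer

reference  = "INVOICE NO. 12345 TOTAL DUE $1,050.00"  # ground-truth transcription
hypothesis = "INVOICE NO. 12345 TOTAL DUE $1.050.00"  # model OCR output

print("WER:", jiwer.wer(reference, hypothesis))  # word error rate
print("CER:", jiwer.cer(reference, hypothesis))  # character error rate
```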
Successfully ran the following E2E tests using the official Unsloth notebooks on both architectures:
Associated PRs
Also requires unsloth PR #2780 (`Gemma3ForCausalLM` object has no attribute `self.llm`) to resolve the GRPO attribute error.
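An illustrative compatibility shim in the spirit of that fix (not necessarily the PR's exact code):

```python
# Alias the text decoder so code paths expecting `.llm` also work on
# Gemma3ForCausalLM, whose decoder is stored in `.model`. Hypothetical sketch.
if not hasattr(model, "llm"):
    model.llm = model.model
```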