
Conversation


@vfdev-5 vfdev-5 commented Jul 16, 2025

What does this PR do?

  • Added a training script for the gemma example based on the lm1b_nnx example:
    • Added training code from lm1b_nnx
    • Added support for distributed training via sharding, tested on 2 GPUs and on a TPU VM with 4 devices.
    • Added a mixed precision config
    • Added support for multiple samples per sequence, as in lm1b (nothing to modify in the attention layer; just use an appropriate attention mask and provide shifted data)
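The last bullet can be sketched with a small framework-agnostic helper (a hypothetical `packed_causal_mask`, not code from this PR): packing several samples into one sequence only needs a mask that is causal within each sample and blocks attention across samples, plus the usual next-token shift of the data.

```python
import numpy as np

def packed_causal_mask(segment_ids):
    """Boolean [seq, seq] mask for packed sequences: position i may attend
    to position j iff j <= i (causal) and both tokens belong to the same
    packed sample (same segment id)."""
    seg = np.asarray(segment_ids)
    causal = np.tril(np.ones((seg.size, seg.size), dtype=bool))
    same_sample = seg[:, None] == seg[None, :]
    return causal & same_sample

def shift_for_next_token(tokens):
    """Standard next-token setup: inputs are tokens[:-1], targets tokens[1:]."""
    tokens = np.asarray(tokens)
    return tokens[:-1], tokens[1:]
```

For example, `packed_causal_mask([1, 1, 2, 2])` lets position 1 attend to position 0 (same sample, causal) but blocks positions 2 and 3 from attending to the first sample's tokens.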

Addresses #4740

  • 2 GPUs training logs, gemma3-1b model config: link (1000 iters)
  • TPU v4-8 training logs, gemma3-1b config: link (40000 iters)

@vfdev-5 vfdev-5 force-pushed the add-train-script-gemma-example branch from 0c16694 to 2eb7baa Compare July 17, 2025 00:12
@vfdev-5 vfdev-5 marked this pull request as ready for review July 17, 2025 08:04
@vfdev-5 vfdev-5 requested a review from IvyZX July 17, 2025 08:04
@copybara-service copybara-service bot merged commit 97f4a49 into main Jul 22, 2025
19 of 20 checks passed
@copybara-service copybara-service bot deleted the add-train-script-gemma-example branch July 22, 2025 20:01
