168 changes: 168 additions & 0 deletions examples/gemma/README.md
@@ -0,0 +1,168 @@

## Language modeling
Trains a Gemma model on the One Billion Word Benchmark (lm1b; Chelba *et al.*, 2013).

This example is based on the `lm1b_nnx` example script and, like it, uses a linear learning rate warmup followed by an inverse square root learning rate schedule.
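
As a rough illustration of the schedule shape described above, here is a minimal sketch in plain Python. The function name `learning_rate_at` is hypothetical, and the exact formula used by the training script may differ (for example in how the decay is shifted); the default values are taken from `learning_rate` and `warmup_steps` in `configs/default.py`.

```python
import math


def learning_rate_at(step: int, base_lr: float = 0.0016, warmup_steps: int = 1000) -> float:
  """Linear warmup to base_lr, then inverse square root decay (illustrative only)."""
  if warmup_steps < 1:
    raise ValueError("This sketch assumes warmup_steps >= 1.")
  if step < warmup_steps:
    # Linear ramp from 0 at step 0 up to base_lr at step == warmup_steps.
    return base_lr * step / warmup_steps
  # Decay proportional to 1 / sqrt(step), continuous at step == warmup_steps.
  return base_lr * math.sqrt(warmup_steps / step)


for step in (0, 500, 1000, 4000):
  print(step, learning_rate_at(step))
```

With these defaults the rate peaks at the end of warmup and halves each time the step count quadruples afterwards.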


### Requirements

* The TensorFlow Datasets (TFDS) `lm1b` dataset needs to be downloaded and prepared (see below).
  A SentencePiece tokenizer vocabulary is generated automatically
  and saved on each training run.
* This example additionally depends on the `sentencepiece` and `tensorflow-text` packages.

### Downloading the LM1B Datasets

We recommend downloading and preparing the TFDS datasets beforehand. You can download and prepare LM1B datasets using TFDS directly: `python -m tensorflow_datasets.scripts.download_and_prepare --datasets=lm1b`.

#### Using Cloud Storage FUSE for TPUs

For Cloud TPUs, we recommend preparing the data on a cheap standard instance and saving the prepared TFDS
data in a storage bucket, from where it can be mounted on the TPU VM using [Cloud Storage FUSE](https://cloud.google.com/storage/docs/cloud-storage-fuse/quickstart-mount-bucket).

##### Copy the preprocessed dataset to Cloud Storage

We assume that the dataset has been downloaded and prepared, and that the `gcloud` CLI is configured. The following commands install the `gcsfuse` CLI, which is used below to set up the storage and copy the dataset:

```bash
# Install gcsfuse CLI
export GCSFUSE_REPO=gcsfuse-`lsb_release -c -s`
# For example, GCSFUSE_REPO=gcsfuse-noble for Ubuntu 24.04

echo "deb https://packages.cloud.google.com/apt $GCSFUSE_REPO main" | sudo tee /etc/apt/sources.list.d/gcsfuse.list
curl https://packages.cloud.google.com/apt/doc/apt-key.gpg | sudo apt-key add -
sudo apt-get update
sudo apt-get install -y fuse gcsfuse --no-install-recommends

gcsfuse -v
# gcsfuse version 2.12.2 (Go version go1.24.0)
```

Find where the LM1B dataset is stored locally:
```bash
python -c "import tensorflow_datasets as tfds; b=tfds.builder('lm1b'); print(b.info.data_dir)"
# For example: /home/user/tensorflow_datasets/lm1b/1.1.0
```

Create a GCS bucket for the dataset and mount it on a local folder. We use the bucket name "flax-lm1b-tfdataset" here, but you can choose another name.
```bash
gcloud storage buckets create gs://flax-lm1b-tfdataset

mkdir -p $HOME/data
gcsfuse flax-lm1b-tfdataset $HOME/data
```

Now copy the data to the bucket:
```bash
# We assume that the prepared dataset is at $HOME/tensorflow_datasets/lm1b/
cp -R $HOME/tensorflow_datasets/lm1b $HOME/data
```

##### Set up the dataset on the TPU VM

We previously chose the bucket name "flax-lm1b-tfdataset" to store the dataset; adapt this name to your setup.

```bash
# On the TPU VM
gcsfuse flax-lm1b-tfdataset $HOME/tensorflow_datasets

ls $HOME/tensorflow_datasets/lm1b/1.1.0/
```
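
The same check can be done from Python before launching a long run. This is a minimal sketch using only the standard library; the helper names `tfds_data_dir` and `find_prepared_versions` are hypothetical, and it assumes the usual TFDS layout of `<data_dir>/lm1b/<version>/`, with `TFDS_DATA_DIR` overriding the default `~/tensorflow_datasets` location.

```python
import os
from pathlib import Path


def tfds_data_dir() -> Path:
  """Returns TFDS_DATA_DIR if set, otherwise the default ~/tensorflow_datasets."""
  return Path(os.environ.get("TFDS_DATA_DIR", Path.home() / "tensorflow_datasets"))


def find_prepared_versions(dataset_name: str = "lm1b") -> list[str]:
  """Lists version subdirectories (for example "1.1.0") found for the dataset."""
  dataset_dir = tfds_data_dir() / dataset_name
  if not dataset_dir.is_dir():
    return []
  return sorted(p.name for p in dataset_dir.iterdir() if p.is_dir())


if __name__ == "__main__":
  versions = find_prepared_versions()
  print(f"{tfds_data_dir()}: lm1b versions found: {versions or 'none'}")
```

An empty result usually means the bucket is not mounted or the dataset was copied to a different path.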

### How to run on GPU(s)

Install JAX with CUDA support, Flax, and the example dependencies with the following commands:
```bash
pip install "jax[cuda12]"
# Check whether GPUs are available:
# python3 -c "import jax; print(jax.devices())"

git clone --depth=1 --branch=main https://github.com/google/flax
cd flax
pip install -e .
cd examples/gemma
pip install -r requirements.txt
```

Start the training:

- train a small transformer model:
```bash
python3 main.py --workdir=$HOME/logs/small_gemma_lm1b --config=configs/small.py
```

- train the Gemma3-4B model:
```bash
python3 main.py --workdir=$HOME/logs/gemma3-4b_lm1b --config=configs/gemma3_4b.py
```

To monitor training with TensorBoard:
```bash
tensorboard --logdir=$HOME/logs
```


### How to run on Cloud TPUs

Set up the TPU VM and install the Flax dependencies on it as described
[here](https://cloud.google.com/tpu/docs/jax-pods) for creating pod slices, or
[here](https://cloud.google.com/tpu/docs/jax-quickstart-tpu-vm) for a single
v4-8 TPU.


First create a single TPUv4-8 VM and connect to it (you can find more detailed
instructions [here](https://cloud.google.com/tpu/docs/jax-quickstart-tpu-vm)):

```bash
ZONE=us-central1-a
TPU_TYPE=v4-8
TPU_NAME=$USER-flax-gemma-lm1b
gcloud compute tpus tpu-vm create $TPU_NAME \
    --zone $ZONE \
    --accelerator-type $TPU_TYPE \
    --version tpu-ubuntu2204-base

gcloud compute tpus tpu-vm ssh $TPU_NAME --zone $ZONE -- \
    -L 6006:localhost:6006
```

Once connected, install JAX:

```bash
pip install "jax[tpu]>=0.2.16" \
    -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
```

Then install Flax + the example dependencies:

```bash
git clone --depth=1 --branch=main https://github.com/google/flax
cd flax
pip install -e .
cd examples/gemma
pip install -r requirements.txt
```

If installing the example dependencies fails, try upgrading `pip`, downgrading `setuptools`, and repeating the installation command:
```bash
# Optionally
# pip install -U pip
# pip install -U "setuptools<70"
# pip install -r requirements.txt
```

And finally start the training:

```bash
python3 main.py --workdir=$HOME/logs/gemma_lm1b_256 --config.per_device_batch_size=32
```

Note that you might want to set `TFDS_DATA_DIR` if the prepared dataset is not at the
default `$HOME/tensorflow_datasets` location used in the Cloud Storage FUSE section above.
You probably also want to start the long-running command above in a `tmux` session and start
some monitoring in a separate pane (note that we forwarded port 6006 locally
above):

```bash
tensorboard --logdir=$HOME/logs
```
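
To relate `--config.per_device_batch_size` to the global batch size and to the number of evaluation steps, the arithmetic from the comment in `configs/default.py` (306_688 / (32 * 8) = 1198) can be sketched as follows. The helper names are hypothetical, and the device count of 8 is an assumption taken from that comment; on your hardware, `jax.device_count()` reports the actual value.

```python
def global_batch_size(per_device_batch_size: int, num_devices: int) -> int:
  """Total number of examples processed per step across all devices."""
  return per_device_batch_size * num_devices


def steps_to_cover(num_examples: int, batch_size: int) -> int:
  """Number of steps needed to see every example once (ceiling division)."""
  return (num_examples + batch_size - 1) // batch_size


batch = global_batch_size(per_device_batch_size=32, num_devices=8)
print(batch)  # 256
print(steps_to_cover(306_688, batch))  # 1198
```

The default `num_eval_steps` of 2_000 is therefore larger than the number of steps needed under these assumptions.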
135 changes: 135 additions & 0 deletions examples/gemma/configs/default.py
@@ -0,0 +1,135 @@
# Copyright 2024 The Flax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Default Hyperparameter configuration."""

import dataclasses

from train import MeshRules, TrainConfig


@dataclasses.dataclass(unsafe_hash=True)
class Config:
  # Path to load or store sentencepiece vocab file.
  vocab_path: str | None = None
  # Vocabulary size if `vocab_path` is not given.
  vocab_size: int = 35_000  # lm1b dataset vocab size: 35913 (Gemma expected vocab size: 262_144)
  # Maximum number of characters to use for training.
  max_corpus_chars: int = 10**7
  # Name of TFDS dataset to use.
  dataset_name: str = 'lm1b'
  # Optional name of TFDS dataset to use for evaluation.
  eval_dataset_name: str = 'lm1b'
  # Optional name of TFDS split to use for evaluation.
  eval_split: str = 'test'
  # Per device batch size for training.
  per_device_batch_size: int = 32
  # Per device batch size for evaluation.
  eval_per_device_batch_size: int = 32

  # Prompts for language model sampling.
  prompts: tuple[str, ...] = (
    'Paris is the capital',
    'Flax is a',
    # From train set:
    'The shutdown was aimed at creating efficiencies as',
    # -> the plant was already operating at its maximum capacity of 3,000 tonnes of cellulose paste per day
    'A big theme of this hire is that there are parts of',
    # -> our operations that to use a pretty trite phrase , need to be taken to the next level ...

    # From test set:
    'Because of Bear Stearns , many analysts are',
    # -> raising the odds that a 2008 recession could be worse than expected
    'Next month , the Brazilian bourse',
    # -> opens a London office
  )
  # Temperature for top_p sampling.
  sampling_temperature: float = 0.0
  # Top-p sampling threshold.
  sampling_top_p: float = 0.95

  # Number of steps to take during training.
  num_train_steps: int = 500_000
  # Number of steps to take during evaluation.
  # Large enough to evaluate all samples: 306_688 / (32 * 8) = 1198
  num_eval_steps: int = 2_000
  # Number of steps to generate predictions.
  # -1 will use the whole eval dataset.
  num_predict_steps: int = 50
  # Base learning rate.
  learning_rate: float = 0.0016
  # Linear learning rate warmup.
  warmup_steps: int = 1000
  # Cross entropy loss label smoothing.
  label_smoothing: float = 0.0
  # Decay factor for AdamW style weight decay.
  weight_decay: float = 0.1
  # Maximum length cutoff for training examples.
  max_target_length: int = 128
  # Maximum length cutoff for eval examples.
  max_eval_target_length: int = 512

  # Gemma transformer name.
  # Possible values defined in transformer.TransformerConfig:
  # (gemma_2b, gemma_7b, gemma2_2b, gemma2_9b, gemma2_27b, gemma3_1b, gemma3_4b, ...)
  transformer_name: str | None = "gemma3_1b"
  # or alternatively define the model using the dict of parameters
  transformer_params: dict | None = None

  # Whether to save model checkpoints.
  save_checkpoints: bool = True
  # Whether to restore from existing model checkpoints.
  restore_checkpoints: bool = True
  # Save a checkpoint every this many steps.
  checkpoint_every_steps: int = 10_000
  # Frequency of eval during training, e.g. every 1_000 steps.
  eval_every_steps: int = 5_000
  # Use bfloat16 mixed precision training instead of float32.
  use_bfloat16: bool = True
  # Integer for PRNG random seed.
  seed: int = 0

  # Parallelism
  mesh_axes: tuple[str, ...] = ('data', 'fsdp', 'tensor')
  axis_rules: MeshRules = MeshRules(
    embed='fsdp',
    mlp='tensor',
    kv='tensor',
    vocab='tensor',
  )
  data_sharding: tuple[str, ...] = ('data', 'fsdp')

  # One axis for each parallelism type may hold a placeholder (-1)
  # value to auto-shard based on available slices and devices.
  # By default, product of the DCN axes should equal number of slices
  # and product of the ICI axes should equal number of devices per slice.
  # ICI (Inter-Chip Interconnect): A high-speed connection between
  # sets of TPU chips, which form the TPU network.
  # DCN (Data Center Network): A connection between the TPU networks;
  # not as fast as ICI.
  # ICI has around 100x the bandwidth of DCN, but it is not a general
  # purpose connection, which is why DCN is necessary for scaling to
  # extremely large ML models.
  dcn_data_parallelism: int = -1
  dcn_fsdp_parallelism: int = 1
  dcn_tensor_parallelism: int = 1
  ici_data_parallelism: int = 1
  ici_fsdp_parallelism: int = -1
  ici_tensor_parallelism: int = 1


def get_config() -> TrainConfig:
  """Get the default hyperparameter configuration."""
  config = Config()
  return TrainConfig(**dataclasses.asdict(config))
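
The `-1` placeholder described in the parallelism comments of `configs/default.py` can be illustrated with a small standalone sketch. This is not the code the training script uses to build its device mesh: `resolve_axis_sizes` is a hypothetical helper, and the slice and device counts (1 slice, 8 devices per slice) are assumptions chosen for the example.

```python
import math


def resolve_axis_sizes(axis_sizes: list[int], total: int) -> list[int]:
  """Replaces at most one -1 placeholder so the product of sizes equals total."""
  if axis_sizes.count(-1) > 1:
    raise ValueError("At most one axis may use the -1 placeholder.")
  if any(s != -1 and s < 1 for s in axis_sizes):
    raise ValueError("Axis sizes must be positive or -1.")
  # Product of the explicitly fixed axes (1 if every axis is a placeholder).
  known = math.prod(s for s in axis_sizes if s != -1)
  if total % known != 0:
    raise ValueError(f"{total} is not divisible by the fixed axes ({known}).")
  resolved = [total // known if s == -1 else s for s in axis_sizes]
  if math.prod(resolved) != total:
    raise ValueError(f"Axis sizes {resolved} do not multiply to {total}.")
  return resolved


# Default values from the config, in ('data', 'fsdp', 'tensor') order.
dcn = resolve_axis_sizes([-1, 1, 1], total=1)  # 1 slice (assumed)
ici = resolve_axis_sizes([1, -1, 1], total=8)  # 8 devices per slice (assumed)
mesh_shape = [d * i for d, i in zip(dcn, ici)]
print(dcn, ici, mesh_shape)  # [1, 1, 1] [1, 8, 1] [1, 8, 1]
```

Under these assumptions the defaults place all 8 devices on the `fsdp` axis, which matches the intent of sharding over ICI within a slice and using data parallelism across slices over DCN.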