diff --git a/examples/gemma/README.md b/examples/gemma/README.md new file mode 100644 index 000000000..8536bb292 --- /dev/null +++ b/examples/gemma/README.md @@ -0,0 +1,168 @@ + +## Language modeling +Trains a Gemma model on the One Billion Word Benchmark (lm1b; Chelba *et al.*, 2013). + +This example is based on the `lm1b_nnx` example script and similarly uses a linear learning rate warmup and an inverse square root learning rate schedule. + + +### Requirements + +* TensorFlow datasets `lm1b` need to be downloaded and prepared (see below). + A sentencepiece tokenizer vocabulary will be automatically generated + and saved on each training run. +* This example additionally depends on the `sentencepiece` and `tensorflow-text` packages. + +### Downloading the LM1B Datasets + +We recommend downloading and preparing the TFDS datasets beforehand. You can download and prepare the LM1B datasets using TFDS directly: `python -m tensorflow_datasets.scripts.download_and_prepare --datasets=lm1b`. + +#### Using Cloud Storage FUSE for TPUs + +For Cloud TPUs, we recommend using a cheap standard instance and saving the prepared TFDS +data on a storage bucket, from where it can be mounted to the TPU VM using [Cloud Storage FUSE](https://cloud.google.com/storage/docs/cloud-storage-fuse/quickstart-mount-bucket). + +##### Copy the preprocessed dataset to the Cloud Storage + +We assume that the dataset was downloaded and prepared. We also assume that the `gcloud` CLI has been configured. 
The following commands help to set up the storage and copy the dataset: + +```bash +# Install gcsfuse CLI +export GCSFUSE_REPO=gcsfuse-`lsb_release -c -s` +# For example, GCSFUSE_REPO=gcsfuse-noble for Ubuntu 24.04 + +echo "deb https://packages.cloud.google.com/apt $GCSFUSE_REPO main" | sudo tee /etc/apt/sources.list.d/gcsfuse.list +curl https://packages.cloud.google.com/apt/doc/apt-key.gpg | sudo apt-key add - +sudo apt-get update +sudo apt-get install -y fuse gcsfuse --no-install-recommends + +gcsfuse -v +# gcsfuse version 2.12.2 (Go version go1.24.0) +``` + +Let's find where the LM1B dataset was stored locally: +```bash +python -c "import tensorflow_datasets as tfds; b=tfds.builder('lm1b'); print(b.info.data_dir)" +# For example: /home/user/tensorflow_datasets/lm1b/1.1.0 +``` + +Let's create a GCS bucket for the dataset and link the bucket to a local folder. We choose the bucket name "flax-lm1b-tfdataset" but this can be changed. +```bash +gcloud storage buckets create gs://flax-lm1b-tfdataset + +mkdir -p $HOME/data +gcsfuse flax-lm1b-tfdataset $HOME/data +``` + +Now let's copy the data to the bucket: +```bash +# Let's assume that the prepared dataset is at $HOME/tensorflow_datasets/lm1b/ +cp -R $HOME/tensorflow_datasets/lm1b $HOME/data +``` + +##### Setup the dataset on TPU VM + +We previously chose the bucket name "flax-lm1b-tfdataset" to store the dataset; adapt this name to your situation. + +```bash +# On the TPU VM +gcsfuse flax-lm1b-tfdataset $HOME/tensorflow_datasets + +ls $HOME/tensorflow_datasets/lm1b/1.1.0/ +``` + +### How to run on GPU(s) + +Install JAX with CUDA support, Flax, and the example dependencies with the following commands: +```bash +pip install jax[cuda12] +# Check whether GPUs are available: +# python3 -c "import jax; print(jax.devices())" + +git clone --depth=1 --branch=main https://github.com/google/flax +cd flax +pip install -e . 
+cd examples/gemma +pip install -r requirements.txt +``` + +Start the training: + +- train a small transformer model: +```bash +python3 main.py --workdir=$HOME/logs/small_gemma_lm1b --config=configs/small.py +``` + +- train the Gemma3-4B model: +```bash +python3 main.py --workdir=$HOME/logs/gemma3-4b_lm1b --config=configs/gemma3_4b.py +``` + +To monitor the training runs with TensorBoard: +```bash +tensorboard --logdir=$HOME/logs +``` + + +### How to run on Cloud TPUs + +Set up the TPU VM and install the Flax dependencies on it as described +[here](https://cloud.google.com/tpu/docs/jax-pods) for creating pod slices, or +[here](https://cloud.google.com/tpu/docs/jax-quickstart-tpu-vm) for a single +v4-8 TPU. + + +First create a single TPUv4-8 VM and connect to it (you can find more detailed +instructions [here](https://cloud.google.com/tpu/docs/jax-quickstart-tpu-vm)): + +```bash +ZONE=us-central1-a +TPU_TYPE=v4-8 +TPU_NAME=$USER-flax-gemma-lm1b +gcloud compute tpus tpu-vm create $TPU_NAME \ + --zone $ZONE \ + --accelerator-type $TPU_TYPE \ + --version tpu-ubuntu2204-base + +gcloud compute tpus tpu-vm ssh $TPU_NAME --zone $ZONE -- \ + -L 6006:localhost:6006 +``` + +Once connected, install JAX: + +```bash +pip install "jax[tpu]>=0.2.16" \ + -f https://storage.googleapis.com/jax-releases/libtpu_releases.html +``` + +Then install Flax + the example dependencies: + +```bash +git clone --depth=1 --branch=main https://github.com/google/flax +cd flax +pip install -e . 
+cd examples/gemma +pip install -r requirements.txt +``` + +In case of errors when installing example dependencies, try to upgrade existing `pip` package and downgrade `setuptools` and repeat the installation command +```bash +# Optionally +# pip install -U pip +# pip install -U "setuptools<70" +# pip install -r requirements.txt +``` + +And finally start the training: + +```bash +python3 main.py --workdir=$HOME/logs/gemma_lm1b_256 --config.per_device_batch_size=32 +``` + +Note that you might want to set `TFDS_DATA_DIR` as explained below. You probably +also want to start the long-running command above in a `tmux` session and start +some monitoring in a separate pane (note that we forwarded port 6006 locally +above): + +```bash +tensorboard --logdir=$HOME/logs +``` diff --git a/examples/gemma/configs/default.py b/examples/gemma/configs/default.py new file mode 100644 index 000000000..f05e1d1ba --- /dev/null +++ b/examples/gemma/configs/default.py @@ -0,0 +1,135 @@ +# Copyright 2024 The Flax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Default Hyperparameter configuration.""" + +import dataclasses + +from train import MeshRules, TrainConfig + + +@dataclasses.dataclass(unsafe_hash=True) +class Config: + # Path to load or store sentencepiece vocab file. + vocab_path: str | None = None + # Vocabulary size if `vocab_path` is not given. 
+ vocab_size: int = 35_000 # lm1b dataset vocab size: 35913 (Gemma expected vocab size: 262_144) + # Maximum number of characters to use for training. + max_corpus_chars: int = 10**7 + # Name of TFDS translation dataset to use. + dataset_name: str = 'lm1b' + # Optional name of TFDS translation dataset to use for evaluation. + eval_dataset_name: str = 'lm1b' + # Optional name of TFDS split to use for evaluation. + eval_split: str = 'test' + # Per device batch size for training. + per_device_batch_size: int = 32 + # Per device batch size for training. + eval_per_device_batch_size: int = 32 + + # Prompt for language model sampling + prompts: tuple[str, ...] = ( + 'Paris is a the capital', + 'Flax is a', + # From train set: + 'The shutdown was aimed at creating efficiencies as', + # -> the plant was already operating at its maximum capacity of 3,000 tonnes of cellulose paste per day + 'A big theme of this hire is that there are parts of', + # -> our operations that to use a pretty trite phrase , need to be taken to the next level ... + + # From test set: + 'Because of Bear Stearns , many analysts are', + # -> raising the odds that a 2008 recession could be worse than expected + 'Next month , the Brazilian bourse', + # -> opens a London office', + ) + # Temperature for top_p sampling. + sampling_temperature: float = 0.0 + # Top-p sampling threshold. + sampling_top_p: float = 0.95 + + # Number of steps to take during training. + num_train_steps: int = 500_000 + # Number of steps to take during evaluation. + # Large enough to evaluate all samples: 306_688 / (32 * 8) = 1198 + num_eval_steps: int = 2_000 + # Number of steps to generate predictions. + # -1 will use the whole eval dataset. + num_predict_steps: int = 50 + # Base learning rate. + learning_rate: float = 0.0016 + # Linear learning rate warmup. + warmup_steps: int = 1000 + # Cross entropy loss label smoothing. + label_smoothing: float = 0.0 + # Decay factor for AdamW style weight decay. 
+ weight_decay: float = 0.1 + # Maximum length cutoff for training examples. + max_target_length: int = 128 + # Maximum length cutoff for eval examples. + max_eval_target_length: int = 512 + + # Gemma transformer name. + # Possible values defined in transformer.TransformerConfig: + # (gemma_2b, gemma_7b, gemma2_2b, gemma2_9b, gemma2_27b, gemma3_1b, gemma3_4b, ...) + transformer_name: str | None = "gemma3_1b" + # or alternatively define the model using the dict of parameters + transformer_params: dict | None = None + + # Whether to save model checkpoints. + save_checkpoints: bool = True + # Whether to restore from existing model checkpoints. + restore_checkpoints: bool = True + # Save a checkpoint every these number of steps. + checkpoint_every_steps: int = 10_000 + # Frequency of eval during training, e.g. every 1_000 steps. + eval_every_steps: int = 5_000 + # Use bfloat16 mixed precision training instead of float32. + use_bfloat16: bool = True + # Integer for PRNG random seed. + seed: int = 0 + + # Parallelism + mesh_axes: tuple[str, ...] = ('data', 'fsdp', 'tensor') + axis_rules: MeshRules = MeshRules( + embed='fsdp', + mlp='tensor', + kv='tensor', + vocab='tensor', + ) + data_sharding: tuple[str, ...] = ('data', 'fsdp') + + # One axis for each parallelism type may hold a placeholder (-1) + # value to auto-shard based on available slices and devices. + # By default, product of the DCN axes should equal number of slices + # and product of the ICI axes should equal number of devices per slice. + # ICI (Inter-Chip Interconnection): A high-speed connection between + # sets of TPU chips, which form the TPU network. + # DCN (Data Center Network): A connection between the TPU networks; + # not as fast as ICI. + # ICI has around 100x the bandwidth of DCN, but it is not a general + # purpose connection, which is why DCN is necessary for scaling to + # extremely large ML models. 
+ dcn_data_parallelism: int = -1 + dcn_fsdp_parallelism: int = 1 + dcn_tensor_parallelism: int = 1 + ici_data_parallelism: int = 1 + ici_fsdp_parallelism: int = -1 + ici_tensor_parallelism: int = 1 + + +def get_config() -> TrainConfig: + """Get the default hyperparameter configuration.""" + config = Config() + return TrainConfig(**dataclasses.asdict(config)) diff --git a/examples/gemma/configs/gemma3_4b.py b/examples/gemma/configs/gemma3_4b.py new file mode 100644 index 000000000..5a86d3e1c --- /dev/null +++ b/examples/gemma/configs/gemma3_4b.py @@ -0,0 +1,138 @@ +# Copyright 2024 The Flax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Default Hyperparameter configuration.""" + +import dataclasses + +from train import MeshRules, TrainConfig + + +@dataclasses.dataclass(unsafe_hash=True) +class Config: + # Path to load or store sentencepiece vocab file. + vocab_path: str | None = None + # Vocabulary size if `vocab_path` is not given. + vocab_size: int = 35_000 # lm1b dataset vocab size: 35913 (Gemma expected vocab size: 262_144) + # Maximum number of characters to use for training. + max_corpus_chars: int = 10**7 + # Name of TFDS translation dataset to use. + dataset_name: str = 'lm1b' + # Optional name of TFDS translation dataset to use for evaluation. + eval_dataset_name: str = 'lm1b' + # Optional name of TFDS split to use for evaluation. + eval_split: str = 'test' + # Per device batch size for training. 
+ per_device_batch_size: int = 32 + # Per device batch size for training. + eval_per_device_batch_size: int = 32 + + # Prompt for language model sampling + prompts: tuple[str, ...] = ( + 'Paris is a the capital', + 'Flax is a', + # From train set: + 'The shutdown was aimed at creating efficiencies as', + # -> the plant was already operating at its maximum capacity of 3,000 tonnes of cellulose paste per day + 'A big theme of this hire is that there are parts of', + # -> our operations that to use a pretty trite phrase , need to be taken to the next level ... + + # From test set: + 'Because of Bear Stearns , many analysts are', + # -> raising the odds that a 2008 recession could be worse than expected + 'Next month , the Brazilian bourse', + # -> opens a London office', + ) + # Temperature for top_p sampling. + sampling_temperature: float = 0.0 + # Top-p sampling threshold. + sampling_top_p: float = 0.95 + + # Number of steps to take during training. + num_train_steps: int = 500_000 + # Number of steps to take during evaluation. + # Large enough to evaluate all samples: 306_688 / (32 * 8) = 1198 + num_eval_steps: int = 2_000 + # Number of steps to generate predictions. + # -1 will use the whole eval dataset. + num_predict_steps: int = 50 + # Base learning rate. + learning_rate: float = 0.0016 + # Linear learning rate warmup. + warmup_steps: int = 1000 + # Cross entropy loss label smoothing. + label_smoothing: float = 0.0 + # Decay factor for AdamW style weight decay. + weight_decay: float = 0.1 + # Maximum length cutoff for training examples. + max_target_length: int = 128 + # Maximum length cutoff for eval examples. + max_eval_target_length: int = 512 + + # Gemma transformer name. + # Possible values defined in transformer.TransformerConfig: + # (gemma_2b, gemma_7b, gemma2_2b, gemma2_9b, gemma2_27b, gemma3_1b, gemma3_4b, ...) 
+ transformer_name: str | None = "gemma3_4b" + # or alternatively define the model using the dict of parameters + transformer_params: dict | None = None + + # Whether to save model checkpoints. + save_checkpoints: bool = True + # Whether to restore from existing model checkpoints. + restore_checkpoints: bool = True + # Save a checkpoint every these number of steps. + checkpoint_every_steps: int = 10_000 + # Frequency of eval during training, e.g. every 1_000 steps. + eval_every_steps: int = 5_000 + # Use bfloat16 mixed precision training instead of float32. + use_bfloat16: bool = True + # Integer for PRNG random seed. + seed: int = 0 + + # Parallelism + mesh_axes: tuple[str, ...] = ('data', 'fsdp', 'tensor') + axis_rules: MeshRules = MeshRules( + embed='fsdp', + mlp='tensor', + kv='tensor', + vocab='tensor', + ) + data_sharding: tuple[str, ...] = ('data', 'fsdp') + + # One axis for each parallelism type may hold a placeholder (-1) + # value to auto-shard based on available slices and devices. + # By default, product of the DCN axes should equal number of slices + # and product of the ICI axes should equal number of devices per slice. + # ICI (Inter-Chip Interconnection): A high-speed connection between + # sets of TPU chips, which form the TPU network. + # DCN (Data Center Network): A connection between the TPU networks; + # not as fast as ICI. + # ICI has around 100x the bandwidth of DCN, but it is not a general + # purpose connection, which is why DCN is necessary for scaling to + # extremely large ML models. 
+ dcn_data_parallelism: int = -1 + dcn_fsdp_parallelism: int = 1 + dcn_tensor_parallelism: int = 1 + ici_data_parallelism: int = 1 + ici_fsdp_parallelism: int = -1 + ici_tensor_parallelism: int = 1 + + def replace(self, **kwargs): + return dataclasses.replace(self, **kwargs) + + +def get_config() -> TrainConfig: + """Get the default hyperparameter configuration.""" + config = Config() + return TrainConfig(**dataclasses.asdict(config)) diff --git a/examples/gemma/configs/small.py b/examples/gemma/configs/small.py new file mode 100644 index 000000000..b3832c625 --- /dev/null +++ b/examples/gemma/configs/small.py @@ -0,0 +1,159 @@ +# Copyright 2024 The Flax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Default Hyperparameter configuration.""" + +import dataclasses + +from train import MeshRules, TrainConfig + + +@dataclasses.dataclass(unsafe_hash=True) +class Config: + # Path to load or store sentencepiece vocab file. + vocab_path: str | None = None + # Vocabulary size if `vocab_path` is not given. + vocab_size: int = 35_000 # lm1b dataset vocab size: 35913 (Gemma expected vocab size: 262_144) + # Maximum number of characters to use for training. + max_corpus_chars: int = 10**7 + # Name of TFDS translation dataset to use. + dataset_name: str = 'lm1b' + # Optional name of TFDS translation dataset to use for evaluation. + eval_dataset_name: str = 'lm1b' + # Optional name of TFDS split to use for evaluation. 
+ eval_split: str = 'test' + # Per device batch size for training. + per_device_batch_size: int = 32 + # Per device batch size for training. + eval_per_device_batch_size: int = 32 + + # Prompt for language model sampling + prompts: tuple[str, ...] = ( + 'Paris is a the capital', + 'Flax is a', + # From train set: + 'The shutdown was aimed at creating efficiencies as', + # -> the plant was already operating at its maximum capacity of 3,000 tonnes of cellulose paste per day + 'A big theme of this hire is that there are parts of', + # -> our operations that to use a pretty trite phrase , need to be taken to the next level ... + + # From test set: + 'Because of Bear Stearns , many analysts are', + # -> raising the odds that a 2008 recession could be worse than expected + 'Next month , the Brazilian bourse', + # -> opens a London office', + ) + # Temperature for top_p sampling. + sampling_temperature: float = 0.0 + # Top-p sampling threshold. + sampling_top_p: float = 0.95 + + # Number of steps to take during training. + num_train_steps: int = 500_000 + # Number of steps to take during evaluation. + num_eval_steps: int = 500 + # Number of steps to generate predictions. + # -1 will use the whole eval dataset. + num_predict_steps: int = 50 + # Base learning rate. + learning_rate: float = 0.0016 + # Linear learning rate warmup. + warmup_steps: int = 1000 + # Cross entropy loss label smoothing. + label_smoothing: float = 0.0 + # Decay factor for AdamW style weight decay. + weight_decay: float = 0.1 + # Maximum length cutoff for training examples. + max_target_length: int = 128 + # Maximum length cutoff for eval examples. + max_eval_target_length: int = 512 + + # Gemma transformer name. + # Possible values defined in transformer.TransformerConfig: + # (gemma_2b, gemma_7b, gemma2_2b, gemma2_9b, gemma2_27b, gemma3_1b, gemma3_4b, ...) 
+ transformer_name: str | None = None + # or alternatively define the model using the dict of parameters + transformer_params: dict | None = dataclasses.field( + default_factory=lambda: { + "num_layers": 6, + "embed_dim": 512, + "hidden_dim": 2048, + "num_heads": 4, + "head_dim": 256, + "num_kv_heads": 1, + "use_post_attn_norm": True, + "use_post_ffw_norm": True, + "use_qk_norm": True, + + "attention_types": (2, 2, 2, 2, 2, 1), # local_sliding, ..., local_sliding, global + "query_pre_attn_norm": 1, # QueryPreAttentionNormalisation.BY_ONE_OVER_SQRT_HEAD_DIM + + "attn_logits_soft_cap": None, + "final_logit_softcap": None, + "sliding_window_size": 128, + "transpose_gating_einsum": True, + "local_base_frequency": 10_000, + "global_base_frequency": 1_000_000, + } + ) + + # Whether to save model checkpoints. + save_checkpoints: bool = True + # Whether to restore from existing model checkpoints. + restore_checkpoints: bool = True + # Save a checkpoint every these number of steps. + checkpoint_every_steps: int = 10_000 + # Frequency of eval during training, e.g. every 1_000 steps. + eval_every_steps: int = 5_000 + # Use bfloat16 mixed precision training instead of float32. + use_bfloat16: bool = True + # Integer for PRNG random seed. + seed: int = 0 + + # Parallelism + mesh_axes: tuple[str, ...] = ('data', 'fsdp', 'tensor') + axis_rules: MeshRules = MeshRules( + embed='fsdp', + mlp='tensor', + kv='tensor', + vocab='tensor', + ) + data_sharding: tuple[str, ...] = ('data', 'fsdp') + + # One axis for each parallelism type may hold a placeholder (-1) + # value to auto-shard based on available slices and devices. + # By default, product of the DCN axes should equal number of slices + # and product of the ICI axes should equal number of devices per slice. + # ICI (Inter-Chip Interconnection): A high-speed connection between + # sets of TPU chips, which form the TPU network. + # DCN (Data Center Network): A connection between the TPU networks; + # not as fast as ICI. 
+ # ICI has around 100x the bandwidth of DCN, but it is not a general + # purpose connection, which is why DCN is necessary for scaling to + # extremely large ML models. + dcn_data_parallelism: int = -1 + dcn_fsdp_parallelism: int = 1 + dcn_tensor_parallelism: int = 1 + ici_data_parallelism: int = 1 + ici_fsdp_parallelism: int = -1 + ici_tensor_parallelism: int = 1 + + def replace(self, **kwargs): + return dataclasses.replace(self, **kwargs) + + +def get_config() -> TrainConfig: + """Get the default hyperparameter configuration.""" + config = Config() + return TrainConfig(**dataclasses.asdict(config)) diff --git a/examples/gemma/configs/tiny.py b/examples/gemma/configs/tiny.py new file mode 100644 index 000000000..9cf4e2285 --- /dev/null +++ b/examples/gemma/configs/tiny.py @@ -0,0 +1,150 @@ +# Copyright 2024 The Flax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Default Hyperparameter configuration.""" + +import dataclasses + +from train import MeshRules, TrainConfig + + +@dataclasses.dataclass(unsafe_hash=True) +class Config: + # Path to load or store sentencepiece vocab file. + vocab_path: str | None = None + # Vocabulary size if `vocab_path` is not given. + vocab_size: int = 35_000 # lm1b dataset vocab size: 35913 (Gemma expected vocab size: 262_144) + # Maximum number of characters to use for training. + max_corpus_chars: int = 10**7 + # Name of TFDS translation dataset to use. 
+ dataset_name: str = 'lm1b' + # Optional name of TFDS translation dataset to use for evaluation. + eval_dataset_name: str = 'lm1b' + # Optional name of TFDS split to use for evaluation. + eval_split: str = 'test' + # Per device batch size for training. + per_device_batch_size: int = 32 + # Per device batch size for training. + eval_per_device_batch_size: int = 32 + + # Prompt for language model sampling + prompts: tuple[str, ...] = ( + 'Paris is a the capital', + 'Flax is a', + # From train set: + 'The shutdown was aimed at creating efficiencies as', + # -> the plant was already operating at its maximum capacity of 3,000 tonnes of cellulose paste per day + 'A big theme of this hire is that there are parts of', + # -> our operations that to use a pretty trite phrase , need to be taken to the next level ... + + # From test set: + 'Because of Bear Stearns , many analysts are', + # -> raising the odds that a 2008 recession could be worse than expected + 'Next month , the Brazilian bourse', + # -> opens a London office', + ) + # Temperature for top_p sampling. + sampling_temperature: float = 0.0 + # Top-p sampling threshold. + sampling_top_p: float = 0.95 + + # Number of steps to take during training. + num_train_steps: int = 500_000 + # Number of steps to take during evaluation. + num_eval_steps: int = 500 + # Number of steps to generate predictions. + # -1 will use the whole eval dataset. + num_predict_steps: int = 20 + # Base learning rate. + learning_rate: float = 0.0016 + # Linear learning rate warmup. + warmup_steps: int = 1000 + # Cross entropy loss label smoothing. + label_smoothing: float = 0.0 + # Decay factor for AdamW style weight decay. + weight_decay: float = 0.1 + # Maximum length cutoff for training examples. + max_target_length: int = 128 + # Maximum length cutoff for eval examples. + max_eval_target_length: int = 512 + + # Gemma transformer name. 
+ # Possible values defined in transformer.TransformerConfig: + # (gemma_2b, gemma_7b, gemma2_2b, gemma2_9b, gemma2_27b, gemma3_1b, gemma3_4b, ...) + transformer_name: str | None = None + # or alternatively define the model using the dict of parameters + transformer_params: dict | None = dataclasses.field( + default_factory=lambda: { + "num_layers": 4, + "embed_dim": 256, + "hidden_dim": 256 * 4 // 2, # embed_dim * num_heads // 2 + "num_heads": 4, + "head_dim": 128, + "num_kv_heads": 1, + "use_post_attn_norm": False, + "use_post_ffw_norm": False, + "attention_types": (1, 1, 1, 1), # global * num_layers + "final_logit_softcap": None, + } + ) + + # Whether to save model checkpoints. + save_checkpoints: bool = True + # Whether to restore from existing model checkpoints. + restore_checkpoints: bool = True + # Save a checkpoint every these number of steps. + checkpoint_every_steps: int = 10_000 + # Frequency of eval during training, e.g. every 1_000 steps. + eval_every_steps: int = 5_000 + # Use bfloat16 mixed precision training instead of float32. + use_bfloat16: bool = True + # Integer for PRNG random seed. + seed: int = 0 + + # Parallelism + mesh_axes: tuple[str, ...] = ('data', 'fsdp', 'tensor') + axis_rules: MeshRules = MeshRules( + embed='fsdp', + mlp='tensor', + kv='tensor', + vocab='tensor', + ) + data_sharding: tuple[str, ...] = ('data', 'fsdp') + + # One axis for each parallelism type may hold a placeholder (-1) + # value to auto-shard based on available slices and devices. + # By default, product of the DCN axes should equal number of slices + # and product of the ICI axes should equal number of devices per slice. + # ICI (Inter-Chip Interconnection): A high-speed connection between + # sets of TPU chips, which form the TPU network. + # DCN (Data Center Network): A connection between the TPU networks; + # not as fast as ICI. 
+ # ICI has around 100x the bandwidth of DCN, but it is not a general + # purpose connection, which is why DCN is necessary for scaling to + # extremely large ML models. + dcn_data_parallelism: int = -1 + dcn_fsdp_parallelism: int = 1 + dcn_tensor_parallelism: int = 1 + ici_data_parallelism: int = 1 + ici_fsdp_parallelism: int = -1 + ici_tensor_parallelism: int = 1 + + def replace(self, **kwargs): + return dataclasses.replace(self, **kwargs) + + +def get_config() -> TrainConfig: + """Get the default hyperparameter configuration.""" + config = Config() + return TrainConfig(**dataclasses.asdict(config)) diff --git a/examples/gemma/input_pipeline.py b/examples/gemma/input_pipeline.py new file mode 100644 index 000000000..da9ae4733 --- /dev/null +++ b/examples/gemma/input_pipeline.py @@ -0,0 +1,381 @@ +# Copyright 2024 The Flax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Input pipeline for a LM1B dataset.""" + +import os +import typing + +import tensorflow as tf +import tensorflow_datasets as tfds +import tokenizer +from clu import deterministic_data + +if typing.TYPE_CHECKING: + from train import TrainConfig + +AUTOTUNE = tf.data.experimental.AUTOTUNE +Features = dict[str, tf.Tensor] + + +class NormalizeFeatureNamesOp: + """Normalizes feature names to 'inputs' and 'targets'.""" + + def __call__(self, features: Features) -> Features: + features['inputs'] = features.pop('text') + # Unnecessary step used for uniformizing with examples/wmt. 
+ features['targets'] = features['inputs'] + return features + + +def get_raw_dataset(dataset_name: str, split: str) -> tf.data.Dataset: + """Loads a raw text dataset and normalizes feature keys. + + Args: + dataset_name: TFDS dataset name. + split: Split to use. This must be the full split. We shard the split across + multiple hosts and currently don't support sharding subsplits. + + Returns: + Dataset with source and target language features mapped to 'inputs' and + 'targets'. + """ + split = tfds.split_for_jax_process(split, drop_remainder=True) + ds = tfds.load(dataset_name, split=split) + ds = ds.map(NormalizeFeatureNamesOp(), num_parallel_calls=AUTOTUNE) + return ds + + +def pack_dataset( + dataset: tf.data.Dataset, + key2length: int | dict[str, int], + keys: list[str] | None = None, +) -> tf.data.Dataset: + """Creates a 'packed' version of a dataset on-the-fly. + + Adapted from the mesh-tf implementation. + + This is meant to replace the irritation of having to create a separate + "packed" version of a dataset to train efficiently on TPU. + Each example in the output dataset represents several examples in the + input dataset. + For each key in the input dataset, two additional keys are created: + _segmentation: an int32 tensor identifying the parts + representing the original example. + _position: an int32 tensor identifying the position within the original + example. + Example: + Two input examples get combined to form an output example. + The input examples are: + {"inputs": [8, 7, 1, 0], "targets":[4, 1, 0]} + {"inputs": [2, 3, 4, 1], "targets":[5, 6, 1]} + The output example is: + { + "inputs": [8, 7, 1, 2, 3, 4, 1, 0, 0, 0] + "inputs_segmentation": [1, 1, 1, 2, 2, 2, 2, 0, 0, 0] + "inputs_position": [0, 1, 2, 0, 1, 2, 3, 0, 0, 0] + "targets": [4, 1, 5, 6, 1, 0, 0, 0, 0, 0] + "targets_segmentation": [1, 1, 2, 2, 2, 0, 0, 0, 0, 0] + "targets_position": [0, 1, 0, 1, 2, 0, 0, 0, 0, 0] + } + 0 represents padding in both the inputs and the outputs. 
+ Sequences in the incoming examples are truncated to length "length", and the + sequences in the output examples all have fixed (padded) length "length". + + Args: + dataset: a tf.data.Dataset + key2length: an integer, or a dict from feature-key to integer + keys: a list of strings (e.g. ["inputs", "targets"]) + + Returns: + a tf.data.Dataset + """ + shapes = tf.nest.map_structure(lambda spec: spec.shape, dataset.element_spec) + if keys is None: + keys = list(shapes.keys()) + for k in keys: + if k not in shapes: + raise ValueError( + 'Key %s not found in dataset. Available keys are %s' + % (k, shapes.keys()) + ) + if not shapes[k].is_compatible_with(tf.TensorShape([None])): # type: ignore[wrong-arg-types] + raise ValueError('Tensors to be packed must be one-dimensional.') + # make sure that the length dictionary contains all keys as well as the + # keys suffixed by "_segmentation" and "_position" + if isinstance(key2length, int): + key2length = {k: key2length for k in keys} + for k in keys: + for suffix in ['_segmentation', '_position']: + key2length[k + suffix] = key2length[k] + + # trim to length + dataset = dataset.map( + lambda x: {k: x[k][: key2length[k]] for k in keys}, + num_parallel_calls=AUTOTUNE, + ) + # Setting batch_size=length ensures that the concatenated sequences (if they + # have length >=1) are sufficient to fill at least one packed example. + batch_size = max(key2length.values()) + dataset = dataset.padded_batch( + batch_size, padded_shapes={k: [-1] for k in keys} + ) + dataset = _pack_with_tf_ops(dataset, keys, key2length) + + # Set the Tensor shapes correctly since they get lost in the process. + def my_fn(x): + return {k: tf.reshape(v, [key2length[k]]) for k, v in x.items()} + + return dataset.map(my_fn, num_parallel_calls=AUTOTUNE) + + +def _pack_with_tf_ops( + dataset: tf.data.Dataset, keys: list[str], key2length: dict[str, int] +) -> tf.data.Dataset: + """Helper-function for packing a dataset which has already been batched. 
+ + Helper for pack_dataset(). Uses tf.while_loop. + + Args: + dataset: a dataset containing padded batches of examples. + keys: a list of strings + key2length: a dict from feature-key to integer + + Returns: + a dataset. + """ + empty_example = {} + for k in keys: + empty_example[k] = tf.zeros([0], dtype=tf.int32) + empty_example[k + '_position'] = tf.zeros([0], dtype=tf.int32) + keys_etc = empty_example.keys() + + def write_packed_example(partial, outputs): + new_partial = empty_example.copy() + new_outputs = {} + for k in keys_etc: + new_outputs[k] = outputs[k].write( + outputs[k].size(), + tf.pad(partial[k], [[0, key2length[k] - tf.size(partial[k])]]), + ) + return new_partial, new_outputs + + def map_fn(x): + """Internal function to flat_map over. + + Consumes a batch of input examples and produces a variable number of output + examples. + Args: + x: a padded batch of examples + + Returns: + a tf.data.Dataset + """ + partial = empty_example.copy() + i = tf.zeros([], dtype=tf.int32) + dynamic_batch_size = tf.shape(x[keys[0]])[0] + outputs = {} + for k in keys: + outputs[k] = tf.TensorArray( + tf.int32, size=0, dynamic_size=True, element_shape=[key2length[k]] + ) + outputs[k + '_position'] = tf.TensorArray( + tf.int32, size=0, dynamic_size=True, element_shape=[key2length[k]] + ) + + def body_fn(i, partial, outputs): + """Body function for while_loop. + + Args: + i: integer scalar + partial: dictionary of Tensor (partially-constructed example) + outputs: dictionary of TensorArray + + Returns: + A triple containing the new values of the inputs. 
+ """ + can_append = True + one_example = {} + for k in keys: + val = tf.cast(x[k][i], tf.int32) + val = val[: tf.reduce_sum(tf.cast(tf.not_equal(val, 0), tf.int32))] + one_example[k] = val + for k in keys: + can_append = tf.logical_and( + can_append, + tf.less_equal( + tf.size(partial[k]) + tf.size(one_example[k]), key2length[k] + ), + ) + + def false_fn(): + return write_packed_example(partial, outputs) + + def true_fn(): + return partial, outputs + + partial, outputs = tf.cond(can_append, true_fn, false_fn) + new_partial = {} + for k in keys: + new_seq = one_example[k][: key2length[k]] + new_seq_len = tf.size(new_seq) + new_partial[k] = tf.concat([partial[k], new_seq], 0) + new_partial[k + '_position'] = tf.concat( + [partial[k + '_position'], tf.range(new_seq_len)], 0 + ) + partial = new_partial + return i + 1, partial, outputs + + # For loop over all examples in the batch. + i, partial, outputs = tf.while_loop( + cond=lambda *_: True, + body=body_fn, + loop_vars=(i, partial, outputs), + shape_invariants=( + tf.TensorShape([]), + {k: tf.TensorShape([None]) for k in keys_etc}, # type: ignore[wrong-arg-types] + {k: tf.TensorShape(None) for k in keys_etc}, # type: ignore[wrong-arg-types] + ), + maximum_iterations=dynamic_batch_size, + ) + _, outputs = write_packed_example(partial, outputs) + packed = {k: outputs[k].stack() for k in keys_etc} + for k in keys: + packed[k + '_segmentation'] = tf.cumsum( + tf.cast(tf.equal(packed[k + '_position'], 0), tf.int32), axis=1 + ) * tf.cast(tf.not_equal(packed[k], 0), tf.int32) + return packed + + dataset = dataset.map(map_fn, num_parallel_calls=AUTOTUNE) + return dataset.unbatch() + + +def shift_data_by_truncation(x): + # https://github.com/AI-Hypercomputer/maxtext/blob/7fe1de75b3919c0fda00d23ad6cb29def9098362/MaxText/input_pipeline/_input_pipeline_utils.py#L53 + x["inputs"] = x["inputs"][:-1] + x["targets"] = x["targets"][1:] + return x + + +# ----------------------------------------------------------------------------- +# 
Main dataset prep routines. +# ----------------------------------------------------------------------------- +def preprocess_data( + dataset, + shuffle: bool, + num_epochs: int | None = 1, + pack_examples: bool = True, + shuffle_buffer_size: int = 1024, + max_length: int = 512, + batch_size: int = 256, + drop_remainder: bool = True, + prefetch_size: int = AUTOTUNE, + shift: bool = True, +): + """Shuffle and batch/pack the given dataset.""" + + def length_filter(max_len): + def filter_fn(x): + source, target = x['inputs'], x['targets'] + l = tf.maximum(tf.shape(source)[0], tf.shape(target)[0]) + return tf.less(l, max_len + 1) + + return filter_fn + + if max_length > 0: + dataset = dataset.filter(length_filter(max_length)) + + if shuffle: + dataset = dataset.shuffle(shuffle_buffer_size) + dataset = dataset.repeat(num_epochs) + + # Shift inputs for teacher-forced training + if shift: + dataset = dataset.map( + shift_data_by_truncation, num_parallel_calls=AUTOTUNE, deterministic=True + ) + + if pack_examples: + dataset = pack_dataset(dataset, max_length) + dataset = dataset.batch(batch_size, drop_remainder=drop_remainder) + else: # simple (static-shape) padded batching + dataset = dataset.padded_batch( + batch_size, + padded_shapes={'inputs': max_length, 'targets': max_length}, + padding_values={'inputs': 0, 'targets': 0}, + drop_remainder=drop_remainder, + ) + + if prefetch_size: + dataset = dataset.prefetch(prefetch_size) + + return dataset + + +def get_datasets( + config: "TrainConfig", + *, + n_devices: int, + vocab_path: str | None = None, +): + """Load and return dataset of batched examples for use during training.""" + if vocab_path is None: + vocab_path = os.path.expanduser('~/lm1b_sentencepiece_model') + + train_data = get_raw_dataset(config.dataset_name, 'train') + + if config.eval_dataset_name: + eval_dataset_name = config.eval_dataset_name + else: + eval_dataset_name = config.dataset_name + eval_data = get_raw_dataset(eval_dataset_name, config.eval_split) + 
+ # Tokenize data. + sp_processor = tokenizer.load_or_train_tokenizer( + train_data, + vocab_path=vocab_path, + vocab_size=config.vocab_size, + max_corpus_chars=config.max_corpus_chars, + ) + train_data = train_data.map( + tokenizer.TokenizeOp(sp_processor), num_parallel_calls=AUTOTUNE + ) + eval_data = eval_data.map( + tokenizer.TokenizeOp(sp_processor), num_parallel_calls=AUTOTUNE + ) + + batch_size = config.per_device_batch_size * n_devices + if config.eval_per_device_batch_size > 0: + eval_batch_size = config.eval_per_device_batch_size * n_devices + else: + eval_batch_size = batch_size + + train_ds = preprocess_data( + train_data, + shuffle=True, + num_epochs=None, + pack_examples=True, + batch_size=batch_size, + max_length=config.max_target_length, + ) + + eval_ds = preprocess_data( + eval_data, + shuffle=False, + pack_examples=False, + batch_size=eval_batch_size, + max_length=config.max_eval_target_length, + ) + + return train_ds, eval_ds, sp_processor diff --git a/examples/gemma/input_pipeline_test.py b/examples/gemma/input_pipeline_test.py new file mode 100644 index 000000000..8e188866e --- /dev/null +++ b/examples/gemma/input_pipeline_test.py @@ -0,0 +1,89 @@ +# Copyright 2024 The Flax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import os +import pathlib +import tempfile + +from absl.testing import absltest +import tensorflow_datasets as tfds + +from configs import default +import input_pipeline + +# We just use different values here to verify that the input pipeline uses the +# correct value for the train and eval datasets. +_TARGET_LENGTH = 32 +_EVAL_TARGET_LENGTH = 48 + + +class InputPipelineTest(absltest.TestCase): + + def setUp(self): + super().setUp() + self.train_ds, self.eval_ds = self._get_datasets() + + def _get_datasets(self): + config = default.get_config() + config.per_device_batch_size = 1 + config.eval_per_device_batch_size = 2 + config.vocab_size = 32 + config.max_corpus_chars = 1000 + config.max_target_length = _TARGET_LENGTH + config.max_eval_target_length = _EVAL_TARGET_LENGTH + + vocab_path = os.path.join(tempfile.mkdtemp(), 'sentencepiece_model') + + # Go two directories up to the root of the flax directory. + # "/path/to/flax/examples/gemma/input_pipeline_test.py" -> "/path/to/flax" + flax_root_dir = pathlib.Path(__file__).absolute().parents[2] + data_dir = str(flax_root_dir) + '/.tfds/metadata' # pylint: disable=unused-variable + + with tfds.testing.mock_data(num_examples=128, data_dir=data_dir): + train_ds, eval_ds, _ = input_pipeline.get_datasets( + n_devices=2, config=config, vocab_path=vocab_path + ) + return train_ds, eval_ds + + def test_train_ds(self): + expected_shape = [2, _TARGET_LENGTH] # 2 devices. + # For training we pack multiple short examples in one example. + # *_position and *_segmentation indicate the boundaries. + for batch in self.train_ds.take(3): + self.assertEqual( + {k: v.shape.as_list() for k, v in batch.items()}, + { + 'inputs': expected_shape, + 'inputs_position': expected_shape, + 'inputs_segmentation': expected_shape, + 'targets': expected_shape, + 'targets_position': expected_shape, + 'targets_segmentation': expected_shape, + }, + ) + + def test_eval_ds(self): + expected_shape = [4, _EVAL_TARGET_LENGTH] # 2 devices. 
+ for batch in self.eval_ds.take(3): + self.assertEqual( + {k: v.shape.as_list() for k, v in batch.items()}, + { + 'inputs': expected_shape, + 'targets': expected_shape, + }, + ) + + +if __name__ == '__main__': + absltest.main() diff --git a/examples/gemma/layers.py b/examples/gemma/layers.py index 5dafbd15d..dbc782c8d 100644 --- a/examples/gemma/layers.py +++ b/examples/gemma/layers.py @@ -20,7 +20,6 @@ from typing import Any, Union from flax import nnx -import flax.linen as nn import jax.numpy as jnp from jaxtyping import Array, ArrayLike # pylint: disable=g-importing-member,g-multiple-import @@ -31,9 +30,17 @@ class Einsum(nnx.Module): """Einsum is a convenience module for parameterized tensor multiplication.""" - def __init__(self, einsum_str: str, shape: Shape, *, rngs: nnx.Rngs): + def __init__( + self, + einsum_str: str, + shape: Shape, + *, + kernel_init: nnx.Initializer = nnx.initializers.normal(), + rngs: nnx.Rngs, + dtype: Any = jnp.float32, + ): self.einsum_str = einsum_str - self.w = nnx.Param(nn.initializers.normal()(rngs.params(), shape)) + self.w = nnx.Param(kernel_init(rngs.params(), shape, dtype)) def __call__(self, x: ArrayLike) -> Array: return jnp.einsum(self.einsum_str, x, self.w.value) @@ -46,12 +53,20 @@ def shape(self) -> Shape: class RMSNorm(nnx.Module): """RMSNorm layer.""" - def __init__(self, dim: int, *, rngs: nnx.Rngs): - self.scale = nnx.Param(nn.initializers.zeros_init()(rngs.params(), dim)) + def __init__( + self, + dim: int, + *, + scale_init: nnx.Initializer = nnx.initializers.zeros_init(), + rngs: nnx.Rngs, + dtype: Any = jnp.float32, + ): + self.scale = nnx.Param(scale_init(rngs.params(), dim, dtype)) def __call__(self, x: Array) -> Array: + dtype = self.scale.value.dtype var = jnp.mean(jnp.square(x), axis=-1, keepdims=True) - normed_inputs = jnp.asarray(x * jnp.reciprocal(jnp.sqrt(var + 1e-06))) + normed_inputs = jnp.asarray(x * jnp.reciprocal(jnp.sqrt(var + 1e-06)), dtype) # normed_inputs is a rank-K tensor, K > 1 (K is 
typically 2 or 3). scale is # a rank-1 tensor. To avoid implicit rank-promotion, reshape scale to # a (1, ..., 1, D) tensor, so the rank of scale matches normed_inputs. diff --git a/examples/gemma/main.py b/examples/gemma/main.py new file mode 100644 index 000000000..f4185e216 --- /dev/null +++ b/examples/gemma/main.py @@ -0,0 +1,66 @@ +# Copyright 2024 The Flax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Main file for training Gemma model on the One Billion Word Benchmark dataset. + +This file is intentionally kept short. The majority of the logic is in libraries +that can be easily tested and imported in Colab. +""" + +import jax +import tensorflow as tf +import train +from absl import app, flags, logging +from clu import platform +from ml_collections import config_flags + +FLAGS = flags.FLAGS + +flags.DEFINE_string('workdir', None, 'Directory to store model data.') +config_flags.DEFINE_config_file( + 'config', + 'configs/default.py', + 'File path to the training hyperparameter configuration.', + lock_config=True, +) +flags.mark_flags_as_required(['workdir']) + + +def main(argv): + if len(argv) > 1: + raise app.UsageError('Too many command-line arguments.') + + # Hide any GPUs from TensorFlow. Otherwise TF might reserve memory and make + # it unavailable to JAX. 
+ tf.config.experimental.set_visible_devices([], 'GPU') + + logging.info('JAX process: %d / %d', jax.process_index(), jax.process_count()) + logging.info('JAX local devices: %r', jax.local_devices()) + + # Add a note so that we can tell which task is which JAX host. + # (Depending on the platform task 0 is not guaranteed to be host 0) + platform.work_unit().set_task_status( + f'process_index: {jax.process_index()}, ' + f'process_count: {jax.process_count()}' + ) + platform.work_unit().create_artifact( + platform.ArtifactType.DIRECTORY, FLAGS.workdir, 'workdir' + ) + + train.train_and_evaluate(FLAGS.config, FLAGS.workdir) + + +if __name__ == '__main__': + jax.config.config_with_absl() + app.run(main) diff --git a/examples/gemma/modules.py b/examples/gemma/modules.py index fd8fff319..4fc4a9b0d 100644 --- a/examples/gemma/modules.py +++ b/examples/gemma/modules.py @@ -24,7 +24,6 @@ import layers import positional_embeddings import sow_lib -import flax.linen as nn import jax import jax.numpy as jnp from jaxtyping import Array, ArrayLike # pylint: disable=g-importing-member,g-multiple-import @@ -50,10 +49,12 @@ def __init__( vocab_size: int, embed_dim: int, *, + embedding_init: nnx.Initializer = nnx.initializers.normal(), + dtype: Any = jnp.float32, rngs: nnx.Rngs, ): self.input_embedding = nnx.Param( - nn.initializers.normal()(rngs.params(), (vocab_size, embed_dim)) + embedding_init(rngs.params(), (vocab_size, embed_dim), dtype) ) def encode(self, x: ArrayLike) -> Array: @@ -91,7 +92,14 @@ def __init__( attn_logits_soft_cap: float | None = None, sliding_window_size: int | None = None, use_qk_norm: bool = False, - sow_config: sow_lib.SowConfig = sow_lib.SowConfig() + sow_config: sow_lib.SowConfig = sow_lib.SowConfig(), + dtype: Any = jnp.float16, + kernel_init: nnx.Initializer = nnx.initializers.normal(), + scale_init: nnx.Initializer = nnx.initializers.zeros_init(), + attn_vec_einsum_kernel_init: nnx.Initializer | None = None, + qkv_einsum_kernel_init: nnx.Initializer 
| None = None, + q_einsum_kernel_init: nnx.Initializer | None = None, + kv_einsum_kernel_init: nnx.Initializer | None = None, ): if attn_type == AttentionType.LOCAL_SLIDING and sliding_window_size is None: raise ValueError( @@ -102,9 +110,12 @@ def __init__( self.attn_type = attn_type self.sliding_window_size = sliding_window_size self.attn_logits_soft_cap = attn_logits_soft_cap + attn_vec_einsum_kernel_init = attn_vec_einsum_kernel_init if attn_vec_einsum_kernel_init else kernel_init self.attn_vec_einsum = layers.Einsum( einsum_str='BTNH,NHD->BTD', shape=(num_heads, head_dim, features), + kernel_init=attn_vec_einsum_kernel_init, + dtype=dtype, rngs=rngs, ) self.rope_base_frequency = rope_base_frequency @@ -113,25 +124,44 @@ def __init__( self.sow_config = sow_config if num_heads == num_kv_heads: + qkv_einsum_kernel_init = qkv_einsum_kernel_init if qkv_einsum_kernel_init else kernel_init self.qkv_einsum = layers.Einsum( einsum_str='BTD,SNDH->SBTNH', shape=(3, num_heads, features, head_dim), + kernel_init=qkv_einsum_kernel_init, + dtype=dtype, rngs=rngs, ) else: + q_einsum_kernel_init = q_einsum_kernel_init if q_einsum_kernel_init else kernel_init + kv_einsum_kernel_init = kv_einsum_kernel_init if kv_einsum_kernel_init else kernel_init self.q_einsum = layers.Einsum( einsum_str='BTD,NDH->BTNH', shape=(num_heads, features, head_dim), + kernel_init=kernel_init, + dtype=dtype, rngs=rngs, ) self.kv_einsum = layers.Einsum( einsum_str='BSD,CKDH->CBSKH', shape=(2, num_kv_heads, features, head_dim), + kernel_init=kv_einsum_kernel_init, + dtype=dtype, rngs=rngs, ) if self.use_qk_norm: - self._query_norm = layers.RMSNorm(head_dim, rngs=rngs) - self._key_norm = layers.RMSNorm(head_dim, rngs=rngs) + self._query_norm = layers.RMSNorm( + head_dim, + scale_init=scale_init, + dtype=dtype, + rngs=rngs, + ) + self._key_norm = layers.RMSNorm( + head_dim, + scale_init=scale_init, + dtype=dtype, + rngs=rngs, + ) def __call__( self, @@ -266,29 +296,34 @@ def __init__( features: int, 
hidden_dim: int, *, + kernel_init: nnx.Initializer = nnx.initializers.normal(), rngs: nnx.Rngs, - sow_config: sow_lib.SowConfig = sow_lib.SowConfig() + sow_config: sow_lib.SowConfig = sow_lib.SowConfig(), + dtype: Any = jnp.float32, ): self.gate_proj = nnx.Linear( in_features=features, out_features=hidden_dim, use_bias=False, rngs=rngs, - kernel_init=nn.initializers.normal(), + kernel_init=kernel_init, + dtype=dtype, ) self.up_proj = nnx.Linear( in_features=features, out_features=hidden_dim, use_bias=False, rngs=rngs, - kernel_init=nn.initializers.normal(), + kernel_init=kernel_init, + dtype=dtype, ) self.down_proj = nnx.Linear( in_features=hidden_dim, out_features=features, use_bias=False, rngs=rngs, - kernel_init=nn.initializers.normal(), + kernel_init=kernel_init, + dtype=dtype, ) self.sow_config = sow_config @@ -309,25 +344,42 @@ class Block(nnx.Module): def __init__( self, - num_heads: int, - num_kv_heads: int, - embed_dim: int, - head_dim: int, - hidden_dim: int, - use_post_attn_norm: bool, - use_post_ffw_norm: bool, - query_pre_attn_scalar: float, + config, # TransformerConfig attn_type: AttentionType, *, rngs: nnx.Rngs, - rope_base_frequency: int = DEFAULT_ROPE_BASE_FREQUENCY, - rope_scale_factor: float = DEFAULT_ROPE_SCALE_FACTOR, - attn_logits_soft_cap: float | None = None, - sliding_window_size: int | None = None, - use_qk_norm: bool = False, - sow_config: sow_lib.SowConfig = sow_lib.SowConfig() + sow_config: sow_lib.SowConfig = sow_lib.SowConfig(), ): - self.pre_attention_norm = layers.RMSNorm(embed_dim, rngs=rngs) + num_heads = config.num_heads + num_kv_heads = config.num_kv_heads + embed_dim = config.embed_dim + head_dim = config.head_dim + hidden_dim = config.hidden_dim + sliding_window_size = config.sliding_window_size + use_post_attn_norm = config.use_post_attn_norm + use_post_ffw_norm = config.use_post_ffw_norm + query_pre_attn_scalar = config.query_pre_attn_scalar() + if attn_type == AttentionType.LOCAL_SLIDING: + rope_base_frequency = 
config.local_base_frequency + rope_scale_factor = config.local_scale_factor + else: + rope_base_frequency = config.global_base_frequency + rope_scale_factor = config.global_scale_factor + + attn_logits_soft_cap = config.attn_logits_soft_cap + use_qk_norm = config.use_qk_norm + dtype = config.dtype + + self.pre_attention_norm = layers.RMSNorm( + embed_dim, + scale_init=maybe_with_partitioning( + nnx.initializers.zeros_init(), + config.axis_rules, + ("embed", ), + ), + rngs=rngs, + dtype=dtype, + ) self.attn = Attention( num_heads=num_heads, num_kv_heads=num_kv_heads, @@ -342,21 +394,79 @@ def __init__( rngs=rngs, use_qk_norm=use_qk_norm, sow_config=sow_config, + attn_vec_einsum_kernel_init=maybe_with_partitioning( + nnx.initializers.normal(), + config.axis_rules, + (None, "embed", "kv"), # sharded array shape: (num_heads, head_dim, features) + ), + qkv_einsum_kernel_init=maybe_with_partitioning( + nnx.initializers.normal(), + config.axis_rules, + (None, None, "embed", "kv"), # sharded array shape: (3, num_heads, features, head_dim) + ), + q_einsum_kernel_init=maybe_with_partitioning( + nnx.initializers.normal(), + config.axis_rules, + (None, "embed", "kv"), # sharded array shape: (num_heads, features, head_dim) + ), + kv_einsum_kernel_init=maybe_with_partitioning( + nnx.initializers.normal(), + config.axis_rules, + (None, None, "embed", "kv"), # sharded array shape: (2, num_kv_heads, features, head_dim) + ), + scale_init=maybe_with_partitioning( + nnx.initializers.zeros_init(), + config.axis_rules, + ("embed", ), + ), + dtype=dtype, ) if use_post_attn_norm: - self.post_attention_norm = layers.RMSNorm(embed_dim, rngs=rngs) + self.post_attention_norm = layers.RMSNorm( + embed_dim, + scale_init=maybe_with_partitioning( + nnx.initializers.zeros_init(), + config.axis_rules, + ("embed", ), + ), + rngs=rngs, + dtype=dtype, + ) else: self.post_attention_norm = None - self.pre_ffw_norm = layers.RMSNorm(embed_dim, rngs=rngs) + self.pre_ffw_norm = layers.RMSNorm( + embed_dim, 
+ scale_init=maybe_with_partitioning( + nnx.initializers.zeros_init(), + config.axis_rules, + ("embed", ), + ), + rngs=rngs, + dtype=dtype, + ) self.mlp = FeedForward( features=embed_dim, hidden_dim=hidden_dim, + kernel_init=maybe_with_partitioning( + nnx.initializers.normal(), + config.axis_rules, + ("embed", "mlp"), + ), rngs=rngs, sow_config=sow_config, ) if use_post_ffw_norm: - self.post_ffw_norm = layers.RMSNorm(embed_dim, rngs=rngs) + self.post_ffw_norm = layers.RMSNorm( + embed_dim, + scale_init=maybe_with_partitioning( + nnx.initializers.zeros_init(), + config.axis_rules, + ("embed", ), + ), + rngs=rngs, + dtype=dtype, + ) else: self.post_ffw_norm = None self.sow_config = sow_config @@ -403,3 +513,9 @@ def init_cache( batch_size=batch_size, dtype=dtype, ) + + +def maybe_with_partitioning(fn, axis_rules, axis_rules_args=()): + if axis_rules is None: + return fn + return nnx.with_partitioning(fn, axis_rules(*axis_rules_args)) diff --git a/examples/gemma/modules_test.py b/examples/gemma/modules_test.py index 6f3140a58..5e606b85a 100644 --- a/examples/gemma/modules_test.py +++ b/examples/gemma/modules_test.py @@ -18,6 +18,7 @@ from absl.testing import parameterized from flax import nnx import modules +import transformer as transformer_lib import jax import jax.numpy as jnp import numpy as np @@ -269,16 +270,23 @@ def test_block( inputs = jnp.ones((batch_size, 1, embed_dim)) attn_mask = jnp.ones((batch_size, 1, cache_size)) + config = transformer_lib.TransformerConfig( + num_heads=num_heads, + num_kv_heads=num_heads, + embed_dim=embed_dim, + head_dim=head_dim, + hidden_dim=1, + use_post_attn_norm=use_post_attn_norm, + use_post_ffw_norm=use_post_ffw_norm, + final_logit_softcap=None, + num_layers=-1, + num_embed=-1, + attention_types=[], + ) + block = modules.Block( - num_heads, - num_heads, - embed_dim, - head_dim, - 1, - use_post_attn_norm, - use_post_ffw_norm, - 1.0, - modules.AttentionType.GLOBAL, + config, + attn_type=modules.AttentionType.GLOBAL, 
rngs=nnx.Rngs(params=0), ) cache = block.init_cache( @@ -313,27 +321,40 @@ def test_post_attention_norm( inputs = jnp.ones((batch_size, 1, embed_dim)) attn_mask = jnp.ones((batch_size, 1, cache_size)) - normed_block = modules.Block( - num_heads, - num_heads, - embed_dim, - head_dim, - 1, + normed_block_config = transformer_lib.TransformerConfig( + num_heads=num_heads, + num_kv_heads=num_heads, + embed_dim=embed_dim, + head_dim=head_dim, + hidden_dim=1, use_post_attn_norm=True, use_post_ffw_norm=False, - query_pre_attn_scalar=1.0, + final_logit_softcap=None, + num_layers=-1, + num_embed=-1, + attention_types=[], + ) + normed_block = modules.Block( + normed_block_config, attn_type=modules.AttentionType.GLOBAL, rngs=nnx.Rngs(params=0), ) - unnormed_block = modules.Block( - num_heads, - num_heads, - embed_dim, - head_dim, - 1, + + unnormed_block_config = transformer_lib.TransformerConfig( + num_heads=num_heads, + num_kv_heads=num_heads, + embed_dim=embed_dim, + head_dim=head_dim, + hidden_dim=1, use_post_attn_norm=False, use_post_ffw_norm=False, - query_pre_attn_scalar=1.0, + final_logit_softcap=None, + num_layers=-1, + num_embed=-1, + attention_types=[], + ) + unnormed_block = modules.Block( + unnormed_block_config, attn_type=modules.AttentionType.GLOBAL, rngs=nnx.Rngs(params=0), ) @@ -373,27 +394,40 @@ def test_post_ffw_norm( inputs = jnp.ones((batch_size, 1, embed_dim)) attn_mask = jnp.ones((batch_size, 1, cache_size)) - normed_block = modules.Block( - num_heads, - num_heads, - embed_dim, - head_dim, - 1, + normed_block_config = transformer_lib.TransformerConfig( + num_heads=num_heads, + num_kv_heads=num_heads, + embed_dim=embed_dim, + head_dim=head_dim, + hidden_dim=1, use_post_attn_norm=False, use_post_ffw_norm=True, - query_pre_attn_scalar=1.0, + final_logit_softcap=None, + num_layers=-1, + num_embed=-1, + attention_types=[], + ) + normed_block = modules.Block( + normed_block_config, attn_type=modules.AttentionType.GLOBAL, rngs=nnx.Rngs(params=0), ) - 
unnormed_block = modules.Block( - num_heads, - num_heads, - embed_dim, - head_dim, - 1, + + unnormed_block_config = transformer_lib.TransformerConfig( + num_heads=num_heads, + num_kv_heads=num_heads, + embed_dim=embed_dim, + head_dim=head_dim, + hidden_dim=1, use_post_attn_norm=False, use_post_ffw_norm=False, - query_pre_attn_scalar=1.0, + final_logit_softcap=None, + num_layers=-1, + num_embed=-1, + attention_types=[], + ) + unnormed_block = modules.Block( + unnormed_block_config, attn_type=modules.AttentionType.GLOBAL, rngs=nnx.Rngs(params=0), ) @@ -411,6 +445,9 @@ def test_post_ffw_norm( all_outputs.append(outputs) normed_output, unnormed_output = all_outputs # pylint: disable=unbalanced-tuple-unpacking + print(normed_output.shape, unnormed_output.shape) + print(f"{normed_output=}") + print(f"{unnormed_output=}") self.assertTrue(jnp.not_equal(normed_output, unnormed_output).all()) diff --git a/examples/gemma/requirements.txt b/examples/gemma/requirements.txt new file mode 100644 index 000000000..68c75e1a2 --- /dev/null +++ b/examples/gemma/requirements.txt @@ -0,0 +1,12 @@ +absl-py~=2.2 +clu==0.0.12 +flax~=0.10 +jax~=0.6 +mlcroissant~=1.0 +numpy~=2.1 +optax~=0.2 +sentencepiece~=0.2 +jaxtyping~=0.3 +tensorflow~=2.19 +tensorflow-datasets~=4.9 +tensorflow-text~=2.19 \ No newline at end of file diff --git a/examples/gemma/sampler.py b/examples/gemma/sampler.py index 4ebdd9e40..2b201a834 100644 --- a/examples/gemma/sampler.py +++ b/examples/gemma/sampler.py @@ -411,7 +411,7 @@ def __call__( input_strings: input prompts to feed to the model for sampling. total_generation_steps: number of generation steps. will correspond to the longest prompt in the batch. - echo: whgether to return the prompt as part of the output sample. + echo: whether to return the prompt as part of the output sample. return_logits: whether to return per-step logits used during generation. forbidden_tokens: list of tokens that are forbidden to be generated. 
Each token must map to a single token id in the vocab. diff --git a/examples/gemma/tokenizer.py b/examples/gemma/tokenizer.py new file mode 100644 index 000000000..fe740f3be --- /dev/null +++ b/examples/gemma/tokenizer.py @@ -0,0 +1,194 @@ +# Copyright 2024 The Flax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Provides op for tokenizing a dataset.""" + +import dataclasses +import os +import tempfile +import time +from typing import Any +from collections.abc import Iterable + +import jax +import tensorflow as tf +import tensorflow_text as tftxt +from absl import logging +from sentencepiece import SentencePieceTrainer, SentencePieceProcessor + +Features = dict[str, tf.Tensor] + + +def _dump_chars_to_textfile( + dataset: tf.data.Dataset, + maxchars: int = int(1e7), + data_keys=('inputs', 'targets'), +) -> tuple[str, int]: + """Write part of a TFDS sentence dataset to lines in a text file. + + Args: + dataset: tf.dataset containing string-data. + maxchars: int: approximate number of characters to save from dataset. + data_keys: Tuple[str]: what keys in dataset to dump from. + + Returns: + name of temp file with dataset bytes, exact number of characters dumped. 
+ """ + char_count = 0 + ds_iter = dataset.as_numpy_iterator() + with tempfile.NamedTemporaryFile( + delete=False, prefix='/tmp/ds_chars' + ) as outfp: + while char_count < maxchars: + example = next(ds_iter) + for k in data_keys: + line = example[k] + b'\n' + char_count += len(line) + outfp.write(line) + return outfp.name, char_count + + +def _train_sentencepiece( + dataset: tf.data.Dataset, + *, + vocab_size: int, + maxchars: int = int(1e7), + model_path: str, + model_type: str = 'unigram', + character_coverage: float = 1.0, + data_keys=('inputs', 'targets'), + pad_id: int = 0, + eos_id: int = 1, + bos_id: int = 2, + unk_id: int = 3, +): + """Train SentencePiece tokenizer from subset of tf dataset. + + Args: + dataset: tf.dataset + vocab_size: int: size of vocab tokens to train. + maxchars: int: number of characters to use for sentencepiece training. + model_path: str: path of model file to save vocab model to. + model_type: str: type of sentencepiece vocab to train. + character_coverage: amount of characters covered by the model, good defaults + are 0.9995 for languages with rich character set like Japanese or Chinese + and 1.0 for other languages with small character set. + data_keys: Tuple[str]: keys of dataset to use for training. + pad_id: int: pad piece id + eos_id: int: end of sentence piece id + bos_id: int: begin of sentence piece id + unk_id: int: unknown piece id + + Returns: + path to the trained sentencepiece vocabulary model. 
+ """ + if model_path.startswith('gs://'): + abs_model_path = model_path + else: + abs_model_path = os.path.abspath(os.path.expanduser(model_path)) + fname, _ = _dump_chars_to_textfile( + dataset, maxchars=maxchars, data_keys=data_keys + ) + with tempfile.NamedTemporaryFile( + delete=False, prefix='/tmp/sp_tmp' + ) as model_fp: + pass # we just want a prefix'd tmp-filename + argstr = ' '.join( + [ + f'--input={fname}', + f'--vocab_size={vocab_size}', + f'--character_coverage={character_coverage}', + f'--model_prefix={model_fp.name}', + f'--model_type={model_type}', + # Setup ids for PAD, EOS, BOS, UNK as 0, 1, 2, 3 + # Default values: + # --unk_id (Override UNK (<unk>) id.) type: int32 default: 0 + # --bos_id (Override BOS (<s>) id. Set -1 to disable BOS.) type: int32 default: 1 + # --eos_id (Override EOS (</s>) id. Set -1 to disable EOS.) type: int32 default: 2 + # --pad_id (Override PAD (<pad>) id. Set -1 to disable PAD.) type: int32 default: -1 + # https://github.com/google/sentencepiece/blob/master/doc/options.md + f'--pad_id={pad_id}', + f'--bos_id={bos_id}', + f'--eos_id={eos_id}', + f'--unk_id={unk_id}', + ] + ) + SentencePieceTrainer.Train(argstr) + if jax.process_index() == 0: + # Use an intermediate filename that is renamed to the target name to address + # create and fill delays. 
+ copy_rename_path = abs_model_path + '.rntmp' + tf.io.gfile.copy(model_fp.name + '.model', copy_rename_path, overwrite=True) + tf.io.gfile.rename(copy_rename_path, abs_model_path, overwrite=True) + logging.info('copied %s to %s', model_fp.name + '.model', abs_model_path) + else: + while not tf.io.gfile.exists(abs_model_path): + time.sleep(1) + time.sleep(1) + return abs_model_path + + +def _load_sentencepiece_tokenizer( + model_path: str, + add_bos: bool = False, + add_eos: bool = True, + reverse: bool = False, +): + """Load a tf-text SentencePiece tokenizer from given model filepath.""" + with tf.io.gfile.GFile(model_path, 'rb') as model_fp: + sp_model = model_fp.read() + sp_tokenizer = tftxt.SentencepieceTokenizer( + model=sp_model, add_bos=add_bos, add_eos=add_eos, reverse=reverse + ) + return sp_tokenizer + + +def load_or_train_tokenizer( + dataset: tf.data.Dataset, + *, + vocab_path: str, + vocab_size: int, + max_corpus_chars: int, + data_keys: tuple[str, str] = ('inputs', 'targets'), +): + """Loads the tokenizer at `vocab_path` or trains one from `dataset`.""" + try: + return _load_sentencepiece_tokenizer(vocab_path) + except tf.errors.NotFoundError: + logging.info('SentencePiece vocab not found, building one from data.') + vocab_path = _train_sentencepiece( + dataset, + vocab_size=vocab_size, + maxchars=max_corpus_chars, + model_path=vocab_path, + data_keys=data_keys, + ) + return _load_sentencepiece_tokenizer(vocab_path) + + +@dataclasses.dataclass +class TokenizeOp: + sp_tokenizer: Any + data_keys: Iterable[str] = ('inputs', 'targets') + + def __call__(self, features: Features) -> Features: + for k in self.data_keys: + features[k] = self.sp_tokenizer.tokenize(features[k]) + return features + + +def load_sentencepiece_processor(vocab_path: str): + spp = SentencePieceProcessor() + spp.load(vocab_path) + return spp diff --git a/examples/gemma/train.py b/examples/gemma/train.py new file mode 100644 index 000000000..b5bc07745 --- /dev/null +++ 
b/examples/gemma/train.py @@ -0,0 +1,603 @@ +# Copyright 2024 The Flax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Language Modeling example. + +This script trains a Transformer on a LM1B dataset. +""" + +# pytype: disable=wrong-arg-count +# pytype: disable=attribute-error + +import dataclasses +import os + +import input_pipeline +import jax +import jax.numpy as jnp +import tokenizer +import transformer as transformer_lib +import numpy as np +import optax +import sampler as sampler_lib +import tensorflow as tf +import utils +from absl import logging +from clu import metric_writers, periodic_actions +from jax import random +from jax.sharding import Mesh, NamedSharding +from jax.sharding import PartitionSpec as P +from utils import TrainState + +from flax import nnx +from flax.training import checkpoints, common_utils + + +@dataclasses.dataclass(unsafe_hash=True) +class MeshRules: + embed: str | None = None + mlp: str | None = None + kv: str | None = None + vocab: str | None = None + + def __call__(self, *keys: str) -> tuple[str, ...]: + return tuple( + getattr(self, key) if key is not None else None + for key in keys + ) + + +@dataclasses.dataclass(unsafe_hash=True) +class TrainConfig: + # Path to load or store sentencepiece vocab file. + vocab_path: str | None + # Vocabulary size if `vocab_path` is not given. + vocab_size: int + # Maximum number of characters to use for training. + max_corpus_chars: int + # Name of TFDS translation dataset to use. 
+ dataset_name: str + # Optional name of TFDS translation dataset to use for evaluation. + eval_dataset_name: str + # Optional name of TFDS split to use for evaluation. + eval_split: str + # Per device batch size for training. + per_device_batch_size: int + # Per device batch size for training. + eval_per_device_batch_size: int + + # Prompt for language model sampling + prompts: tuple[str, ...] + # Temperature for top_p sampling. + sampling_temperature: float + # Top-p sampling threshold. + sampling_top_p: float + + # Number of steps to take during training. + num_train_steps: int + # Number of steps to take during evaluation. + # Large enough to evaluate all samples: 306_688 / (32 * 8) = 1198 + num_eval_steps: int + # Number of steps to generate predictions. + # -1 will use the whole eval dataset. + num_predict_steps: int + # Base learning rate. + learning_rate: float + # Linear learning rate warmup. + warmup_steps: int + # Cross entropy loss label smoothing. + label_smoothing: float + # Decay factor for AdamW style weight decay. + weight_decay: float + # Maximum length cutoff for training examples. + max_target_length: int + # Maximum length cutoff for eval examples. + max_eval_target_length: int + + # Gemma transformer name. + # Possible values defined in transformer.TransformerConfig: + # (gemma_2b, gemma_7b, gemma2_2b, gemma2_9b, gemma2_27b, gemma3_1b, gemma3_4b, ...) + transformer_name: str | None + # or alternatively define the model using the dict of parameters + transformer_params: dict | None + + # Whether to save model checkpoints. + save_checkpoints: bool + # Whether to restore from existing model checkpoints. + restore_checkpoints: bool + # Save a checkpoint every these number of steps. + checkpoint_every_steps: int + # Frequency of eval during training, e.g. every 1_000 steps. + eval_every_steps: int + # Use bfloat16 mixed precision training instead of float32. + use_bfloat16: bool + # Integer for PRNG random seed. 
def rsqrt_schedule(
  init_value: float,
  shift: int = 0,
):
  """Applies a reverse square-root schedule.

  Without a shift the schedule is `lr = init_value / sqrt(step)`. With a
  positive shift it is `lr = init_value * sqrt(shift) / sqrt(step + shift)`,
  which equals `init_value` at step 0 and decays from there.

  Args:
    init_value: Base learning rate (before applying the rsqrt schedule).
    shift: How many steps the rsqrt should be shifted. Shifting the rsqrt
      schedule makes it less steep in the beginning (close to 0).

  Returns:
    A schedule function mapping a step count to a learning rate. With
    `shift=0` the value at step 0 is not finite (division by zero).
  """
  # The sqrt(shift) factor rescales the shifted curve so it starts at
  # init_value. For shift == 0 that factor would be 0 and wipe out the whole
  # schedule, so no rescaling is applied in that case.
  scale = shift**0.5 if shift > 0 else 1.0

  def schedule(count):
    return init_value * (count + shift) ** -0.5 * scale

  return schedule
def train_step(
  state: TrainState,
  batch,
  learning_rate_fn,
  label_smoothing=0.0,
):
  """Perform a single training step.

  Args:
    state: Current train state; its `graphdef`, `params` and `step` are read
      and `apply_gradients` produces the updated state.
    batch: Mapping with 'inputs' and 'targets'. For packed examples it may
      also carry 'inputs_position' and 'inputs_segmentation'; when absent the
      example is treated as a normal, unpacked sequence.
    learning_rate_fn: Schedule evaluated at the current step; inside this
      function its value is only reported in the metrics.
    label_smoothing: Label smoothing constant used for the training loss.

  Returns:
    A `(new_state, metrics)` tuple. `metrics` holds the entries produced by
    `compute_metrics` plus 'learning_rate'.
  """
  inputs = batch['inputs']
  targets = batch['targets']
  # Only present when multiple sequences are packed into the same example.
  inputs_positions = batch.get('inputs_position')
  inputs_segmentation = batch.get('inputs_segmentation')

  # TODO: this should be defined globally
  pad_id = 0
  input_mask = inputs > pad_id
  weights = jnp.where(input_mask, 1, 0).astype(jnp.float32)
  attention_mask = transformer_lib.make_causal_attn_mask(input_mask)  # (B, L, L)

  if inputs_segmentation is not None:
    # Packed examples: a token may only attend to tokens of its own segment.
    # inputs_segmentation: (B, L)
    same_segment = (
      inputs_segmentation[:, :, None] == inputs_segmentation[:, None, :]
    )  # (B, L, L)
    attention_mask = jnp.logical_and(same_segment, attention_mask)

  if inputs_positions is None:
    # Unpacked examples: derive positions from the padding mask, the same way
    # eval_step does.
    inputs_positions = transformer_lib.build_positions_from_mask(input_mask)

  def loss_fn(params):
    """loss function used for training."""
    module = nnx.merge(state.graphdef, params)

    logits, _ = module(
      inputs,
      positions=inputs_positions,
      attention_mask=attention_mask,
      cache=None,
    )

    loss, weight_sum = compute_weighted_cross_entropy(
      logits, targets, weights, label_smoothing
    )
    mean_loss = loss / weight_sum
    return mean_loss, logits

  step = state.step
  lr = learning_rate_fn(step)
  grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
  (_, logits), grads = grad_fn(state.params)
  new_state = state.apply_gradients(grads=grads)
  # NOTE(review): the reported metrics are computed without label smoothing,
  # unlike the loss that is differentiated above - confirm this is intended.
  metrics = compute_metrics(logits, targets, weights)
  metrics['learning_rate'] = lr

  return new_state, metrics
def train_and_evaluate(config: TrainConfig, workdir: str):
  """Runs a training and evaluation loop.

  Sets up the datasets, model, optimizer, mesh and jitted step functions,
  then trains for `config.num_train_steps` steps. Every
  `config.eval_every_steps` steps (and on the last step) it writes averaged
  training metrics, generates text samples for `config.prompts` and runs
  evaluation. Checkpoints are written every `config.checkpoint_every_steps`
  steps when `config.save_checkpoints` is set.

  Args:
    config: Configuration to use. If `config.vocab_path` is None it is set to
      a path inside `workdir`.
    workdir: Working directory for checkpoints and TF summaries. If this
      contains checkpoint training will be resumed from the latest checkpoint.
  """
  workdir = os.path.abspath(workdir)
  tf.io.gfile.makedirs(workdir)

  vocab_path = config.vocab_path
  if vocab_path is None:
    vocab_path = os.path.join(workdir, 'sentencepiece_model')
    # Record the resolved path on the config, which is written out as
    # hyperparameters further below.
    config.vocab_path = vocab_path
  tf.io.gfile.makedirs(os.path.split(vocab_path)[0])

  # Load Dataset
  # ---------------------------------------------------------------------------
  logging.info('Initializing dataset.')
  train_ds, eval_ds, encoder = input_pipeline.get_datasets(
    n_devices=jax.local_device_count(), config=config, vocab_path=vocab_path
  )

  train_iter = iter(train_ds)
  vocab_size = int(encoder.vocab_size())

  logging.info('Initializing model, optimizer, and step functions.')
  # Build Model and Optimizer
  # ---------------------------------------------------------------------------
  if config.transformer_name is not None:
    model_config = transformer_lib.TransformerConfig.from_version_name(
      config.transformer_name,
      num_embed=vocab_size,
      dtype=jnp.bfloat16 if config.use_bfloat16 else jnp.float32,
      axis_rules=config.axis_rules,
    )
  else:
    assert config.transformer_params is not None
    model_config = transformer_lib.TransformerConfig.from_dict(
      **config.transformer_params,
      num_embed=vocab_size,
      dtype=jnp.bfloat16 if config.use_bfloat16 else jnp.float32,
      axis_rules=config.axis_rules,
    )

  # Mesh definition
  devices_array = utils.create_device_mesh(config)
  mesh = Mesh(devices_array, config.mesh_axes)

  start_step = 0
  rng = jax.random.PRNGKey(config.seed)
  rng, init_rng = jax.random.split(rng)
  rng, inference_rng = random.split(rng)

  def constructor(config: transformer_lib.TransformerConfig, key: jax.Array):
    return transformer_lib.Transformer(config, rngs=nnx.Rngs(params=key))

  learning_rate_fn = create_learning_rate_schedule(
    learning_rate=config.learning_rate, warmup_steps=config.warmup_steps
  )

  optimizer = optax.adamw(
    learning_rate_fn,
    b1=0.9,
    b2=0.98,
    eps=1e-9,
    weight_decay=config.weight_decay,
  )

  state, state_sharding = utils.setup_initial_state(
    constructor, optimizer, model_config, init_rng, mesh
  )
  data_sharding = NamedSharding(mesh, P(config.data_sharding))

  if config.restore_checkpoints:
    # Restore unreplicated optimizer + model state from last checkpoint.
    state = checkpoints.restore_checkpoint(workdir, state)
    # Grab last step.
    start_step = int(state.step)

  writer = metric_writers.create_default_writer(
    workdir, just_logging=jax.process_index() > 0
  )
  if start_step == 0:
    writer.write_hparams(dataclasses.asdict(config))

  # compile multidevice versions of train/eval/predict step fn.
  jit_train_step = jax.jit(
    train_step,
    in_shardings=(
      state_sharding,
      data_sharding,
    ),  # type: ignore
    out_shardings=(state_sharding, None),  # type: ignore
    static_argnames=("learning_rate_fn", "label_smoothing"),
    donate_argnums=0,
  )

  jit_eval_step = jax.jit(
    eval_step,
    in_shardings=(
      state_sharding.params,
      data_sharding,
    ),  # type: ignore
    out_shardings=None,  # type: ignore
    static_argnames=("graphdef", "label_smoothing"),
  )

  vocab = tokenizer.load_sentencepiece_processor(vocab_path)
  sampler = sampler_lib.Sampler(
    transformer=nnx.merge(state.graphdef, state.params),
    vocab=vocab,
    cache_size=1024,
  )

  # Main Train Loop
  # ---------------------------------------------------------------------------
  logging.info('Starting training loop.')
  hooks = []
  report_progress = periodic_actions.ReportProgress(
    num_train_steps=config.num_train_steps, writer=writer
  )
  if jax.process_index() == 0:
    hooks += [
      report_progress,
      periodic_actions.Profile(logdir=workdir, num_profile_steps=5),
    ]
  train_metrics = []
  with metric_writers.ensure_flushes(writer):
    for step in range(start_step, config.num_train_steps):
      is_last_step = step == config.num_train_steps - 1

      # Shard data to devices and do a training step.
      with jax.profiler.StepTraceAnnotation('train', step_num=step):
        with report_progress.timed('data'):
          batch = next(train_iter)
          batch = jax.tree.map(lambda x: jnp.asarray(x, device=data_sharding), batch)

        with report_progress.timed('train_step'):
          # Use the configured label smoothing instead of a hard-coded 0.0 so
          # that `config.label_smoothing` actually affects the training loss.
          state, metrics = jit_train_step(
            state, batch, learning_rate_fn, config.label_smoothing
          )
          train_metrics.append(metrics)

      # Quick indication that training is happening.
      logging.log_first_n(logging.INFO, 'Finished training step %d.', 5, step)
      for h in hooks:
        h(step)

      # Write batch loss and lr every step to TB
      # without overwhelming the stdout:
      if jax.process_index() == 0:
        # NOTE(review): relies on the private `_writers` list of the default
        # writer and assumes its last entry is the TensorBoard writer - confirm.
        tb_writer = writer._writers[-1]
        lr = train_metrics[-1]['learning_rate']
        train_batch_loss = train_metrics[-1]['loss']
        denominator = train_metrics[-1]['denominator']
        tb_writer.write_scalars(step, {
          "train_learning_rate": lr,
          "train_loss": train_batch_loss / denominator,
        })

      # Periodic metric handling.
      if (step > 0 and step % config.eval_every_steps == 0) or is_last_step:
        with report_progress.timed('training_metrics'):
          logging.info('Gathering training metrics.')
          train_metrics = common_utils.stack_forest(train_metrics)
          # Remove learning_rate from the summary
          _ = train_metrics.pop('learning_rate')
          metrics_sums = jax.tree.map(jnp.sum, train_metrics)
          denominator = metrics_sums.pop('denominator')
          summary = jax.tree.map(lambda x: x / denominator, metrics_sums)  # pylint: disable=cell-var-from-loop
          summary['perplexity'] = jnp.clip(jnp.exp(summary['loss']), max=1.0e4)
          summary = {'train_' + k: v for k, v in summary.items()}
          writer.write_scalars(step, summary)
          # Start a fresh accumulation window for the next reporting period.
          train_metrics = []

        with report_progress.timed('generate_text'):
          # update sampler's transformer state:
          sampler.transformer_state = state.params
          exemplars = sampler(
            config.prompts,
            total_generation_steps=config.num_predict_steps,
            temperature=config.sampling_temperature,
            top_p=config.sampling_top_p,
            seed=inference_rng,
            echo=True,
          )
          writer.write_texts(step, {'samples': exemplars.text})

        with report_progress.timed('eval'):
          eval_results = evaluate(
            jit_eval_step=jit_eval_step,
            state=state,
            eval_ds=eval_ds,
            num_eval_steps=config.num_eval_steps,
          )
          # (clipped) perplexity after averaging log-perplexity
          eval_results['perplexity'] = jnp.clip(
            jnp.exp(eval_results['loss']), max=1.0e4
          )
          writer.write_scalars(
            step, {'eval_' + k: v for k, v in eval_results.items()}
          )

      # Save a checkpoint on one host after every checkpoint_freq steps.
      save_checkpoint = (
        step % config.checkpoint_every_steps == 0 or is_last_step
      )
      if config.save_checkpoints and save_checkpoint:
        logging.info('Saving checkpoint step %d.', step)
        with report_progress.timed('checkpoint'):
          checkpoints.save_checkpoint_multiprocess(workdir, state, step)
@classmethod
def from_dict(cls, **config: Any) -> TransformerConfig:
  """Builds a config from keyword arguments holding plain values.

  Args:
    **config: Constructor arguments. If 'query_pre_attn_norm' is present its
      value is converted with `QueryPreAttentionNormalisation(...)`; if it is
      absent, `BY_ONE_OVER_SQRT_HEAD_DIM` is used.

  Returns:
    The instance created by `cls(**config)` with the normalisation entry
    replaced by the enum member.
  """
  # Deserialize query_pre_attn_norm values:
  if "query_pre_attn_norm" not in config:
    norm = QueryPreAttentionNormalisation.BY_ONE_OVER_SQRT_HEAD_DIM
  else:
    norm = QueryPreAttentionNormalisation(config["query_pre_attn_norm"])
  return cls(**{**config, "query_pre_attn_norm": norm})
num_kv_heads=4, - final_logit_softcap=30.0, - attention_types=( - modules.AttentionType.LOCAL_SLIDING, - modules.AttentionType.GLOBAL, - ) - * int(num_layers / 2), - use_post_attn_norm=True, - use_post_ffw_norm=True, - query_pre_attn_norm=QueryPreAttentionNormalisation.BY_ONE_OVER_SQRT_HEAD_DIM, - attn_logits_soft_cap=50.0, - sliding_window_size=4096, - ) + config = { + 'num_layers': num_layers, + 'num_embed': 256128, + 'embed_dim': 2304, + 'hidden_dim': 9216, + 'num_heads': 8, + 'head_dim': 256, + 'num_kv_heads': 4, + 'final_logit_softcap': 30.0, + 'attention_types': ( + modules.AttentionType.LOCAL_SLIDING, + modules.AttentionType.GLOBAL, + ) + * int(num_layers / 2), + 'use_post_attn_norm': True, + 'use_post_ffw_norm': True, + 'query_pre_attn_norm': QueryPreAttentionNormalisation.BY_ONE_OVER_SQRT_HEAD_DIM, + 'attn_logits_soft_cap': 50.0, + 'sliding_window_size': 4096, + } + for key, value in override.items(): + config[key] = value + return cls(**config) @classmethod - def gemma2_9b(cls): + def gemma2_9b(cls, **override): num_layers = _NUM_LAYERS_GEMMA2_9B - return cls( - num_layers=num_layers, - num_embed=256128, - embed_dim=3584, - hidden_dim=28672, - num_heads=16, - head_dim=256, - num_kv_heads=8, - final_logit_softcap=30.0, - attention_types=( - modules.AttentionType.LOCAL_SLIDING, - modules.AttentionType.GLOBAL, - ) - * int(num_layers / 2), - use_post_attn_norm=True, - use_post_ffw_norm=True, - attn_logits_soft_cap=50.0, - sliding_window_size=4096, - ) + config = { + "num_layers": num_layers, + "num_embed": 256128, + "embed_dim": 3584, + "hidden_dim": 28672, + "num_heads": 16, + "head_dim": 256, + "num_kv_heads": 8, + "final_logit_softcap": 30.0, + "attention_types": ( + modules.AttentionType.LOCAL_SLIDING, + modules.AttentionType.GLOBAL, + ) * int(num_layers / 2), + "use_post_attn_norm": True, + "use_post_ffw_norm": True, + "attn_logits_soft_cap": 50.0, + "sliding_window_size": 4096, + } + for key, value in override.items(): + config[key] = value + return 
@classmethod
def gemma2_27b(cls, **override):
  """Returns the Gemma 2 27B configuration.

  Args:
    **override: Constructor arguments that replace, or are added to, the
      preset values listed below.
  """
  num_layers = _NUM_LAYERS_GEMMA2_27B
  # Local sliding and global attention alternate, one pair per two layers.
  layer_pattern = (
    modules.AttentionType.LOCAL_SLIDING,
    modules.AttentionType.GLOBAL,
  )
  defaults = dict(
    num_layers=num_layers,
    num_embed=256128,
    embed_dim=4608,
    hidden_dim=72728,
    num_heads=32,
    head_dim=128,
    num_kv_heads=16,
    final_logit_softcap=30.0,
    use_post_attn_norm=True,
    use_post_ffw_norm=True,
    attention_types=layer_pattern * int(num_layers / 2),
    attn_logits_soft_cap=50.0,
    sliding_window_size=4096,
  )
  # Later entries win, so caller overrides take precedence over the presets.
  return cls(**{**defaults, **override})
@classmethod
def gemma3_4b(cls, **override):
  """Returns the Gemma 3 4B configuration.

  Args:
    **override: Constructor arguments that replace, or are added to, the
      preset values listed below.
  """
  num_layers = _NUM_LAYERS_GEMMA3_4B
  defaults = dict(
    num_layers=num_layers,
    final_logit_softcap=None,
    num_embed=262_144,
    embed_dim=2560,
    hidden_dim=2560 * 8 // 2,
    num_heads=8,
    head_dim=256,
    num_kv_heads=4,
    use_post_attn_norm=True,
    use_post_ffw_norm=True,
    use_qk_norm=True,
    attention_types=make_attention_layers_types(
      GEMMA3_ATTENTION_PATTERN, num_layers
    ),
    query_pre_attn_norm=QueryPreAttentionNormalisation.BY_ONE_OVER_SQRT_HEAD_DIM,
    attn_logits_soft_cap=None,
    sliding_window_size=1024,
    transpose_gating_einsum=True,
    local_base_frequency=10_000,
    global_base_frequency=1_000_000,
    global_scale_factor=8.0,
  )
  # Later entries win, so caller overrides take precedence over the presets.
  return cls(**{**defaults, **override})
@classmethod
def gemma3_27b(cls, **override):
  """Returns the Gemma 3 27B configuration.

  Args:
    **override: Constructor arguments that replace, or are added to, the
      preset values listed below.
  """
  num_layers = _NUM_LAYERS_GEMMA3_27B
  defaults = dict(
    num_layers=num_layers,
    final_logit_softcap=None,
    num_embed=262144,
    embed_dim=5376,
    hidden_dim=5376 * 8 // 2,
    num_heads=32,
    head_dim=128,
    num_kv_heads=16,
    use_post_attn_norm=True,
    use_post_ffw_norm=True,
    use_qk_norm=True,
    attention_types=make_attention_layers_types(
      GEMMA3_ATTENTION_PATTERN, num_layers
    ),
    query_pre_attn_norm=QueryPreAttentionNormalisation.BY_ONE_OVER_SQRT_EMBED_DIM_DIV_NUM_HEADS,
    attn_logits_soft_cap=None,
    sliding_window_size=1024,
    transpose_gating_einsum=True,
    local_base_frequency=10_000,
    global_base_frequency=1_000_000,
    global_scale_factor=8.0,
  )
  # Later entries win, so caller overrides take precedence over the presets.
  return cls(**{**defaults, **override})
modules.AttentionType.LOCAL_SLIDING - else config.global_scale_factor, - use_qk_norm=config.use_qk_norm, - sow_config=sow_config, + config=config, + attn_type=attn_type, + sow_config=sow_config, + rngs=rngs, ) for _, attn_type in zip( range(config.num_layers), config.attention_types ) ] - self.final_norm = layers.RMSNorm(config.embed_dim, rngs=rngs) + self.final_norm = layers.RMSNorm( + config.embed_dim, + scale_init=modules.maybe_with_partitioning( + nnx.initializers.zeros_init(), + config.axis_rules, + ("embed", ), + ), + rngs=rngs, + ) self.final_logits_softcap = config.final_logit_softcap self.sow_config = sow_config diff --git a/examples/gemma/utils.py b/examples/gemma/utils.py new file mode 100644 index 000000000..8e9041ed1 --- /dev/null +++ b/examples/gemma/utils.py @@ -0,0 +1,172 @@ +# Copyright 2024 The Flax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Copied over from MaxText (https://github.com/google/maxtext/blob/main/MaxText/max_utils.py). + +import logging +from typing import Any, TYPE_CHECKING +from collections.abc import Callable + +import jax +import jax.numpy as jnp +import numpy as np +from jax.experimental import mesh_utils +from transformer import TransformerConfig, Transformer + +from flax import nnx +from flax.training import train_state + +if TYPE_CHECKING: + from train import TrainConfig + +Dtype = Any +Shape = tuple[int, ...] 
class TrainState(train_state.TrainState):
  """Train state that also carries the NNX graph definition of the model."""

  # Static graph definition obtained from `nnx.split`; stored next to the
  # parameters in the state (see `setup_initial_state`).
  graphdef: nnx.GraphDef[Transformer]


# Mesh utils.
# -----------------------------------------------------------------------------


def create_device_mesh(config: "TrainConfig"):
  """Creates the device mesh described by the DCN/ICI parallelism settings.

  The number of slices is derived from the devices' `slice_index` attribute
  (a single slice is assumed when the attribute is missing). At most one DCN
  axis and one ICI axis may be left unspecified (-1) in the config; they are
  resolved with `fill_unspecified_mesh_axes`.

  Args:
    config: train config providing `dcn_{data,fsdp,tensor}_parallelism` and
      `ici_{data,fsdp,tensor}_parallelism`.

  Returns:
    The device mesh built by `mesh_utils.create_hybrid_device_mesh` when the
    devices expose `slice_index`, otherwise by
    `mesh_utils.create_device_mesh`.
  """
  devices = jax.devices()
  num_devices = len(devices)
  try:
    num_slices = 1 + max(d.slice_index for d in devices)
  except AttributeError:
    # Devices without `slice_index`: treat everything as a single slice.
    num_slices = 1
  num_devices_per_slice = num_devices // num_slices
  logging.info(f'Devices: {devices}')
  logging.info(f'Number of devices: {num_devices}')

  multi_slice_env = hasattr(devices[0], 'slice_index')

  dcn_parallelism = [
    config.dcn_data_parallelism,
    config.dcn_fsdp_parallelism,
    config.dcn_tensor_parallelism,
  ]
  ici_parallelism = [
    config.ici_data_parallelism,
    config.ici_fsdp_parallelism,
    config.ici_tensor_parallelism,
  ]

  # Find possible unspecified parallelisms
  dcn_parallelism = fill_unspecified_mesh_axes(
    dcn_parallelism, num_slices, 'DCN'
  )
  ici_parallelism = fill_unspecified_mesh_axes(
    ici_parallelism, num_devices_per_slice, 'ICI'
  )

  if multi_slice_env:
    mesh = mesh_utils.create_hybrid_device_mesh(
      ici_parallelism, dcn_parallelism
    )
  else:
    mesh = mesh_utils.create_device_mesh(ici_parallelism)

  logging.info(f'Decided on mesh: {mesh}')
  logging.info(f'Mesh shape: {mesh.shape}')

  return mesh


def fill_unspecified_mesh_axes(
  parallelism_vals, target_product, parallelism_type
):
  """Evaluates unspecified DCN/ICI parallelism values.

  Args:
    parallelism_vals: list of per-axis parallelism values; at most one entry
      may be -1, meaning "infer this axis". The list is updated in place.
    target_product: the value the product of all axes must equal (number of
      slices for DCN, number of devices per slice for ICI).
    parallelism_type: 'DCN' or 'ICI'; only used in error messages.

  Returns:
    `parallelism_vals` with the -1 entry (if any) replaced by the inferred
    value.

  Raises:
    AssertionError: if more than one axis is unspecified, if the unspecified
      axis cannot be resolved to an integer >= 1, or if the product of the
      axes does not match `target_product`.
  """
  if -1 in parallelism_vals:
    assert parallelism_vals.count(-1) == 1, (
      f'Found unspecified values (-1) for more than one {parallelism_type} '
      ' parallelism axis. At most one axis can be unspecified.'
    )

    # The product still contains the single -1 placeholder, hence the sign
    # flip to recover the product of the specified axes.
    determined_val = target_product / np.prod(parallelism_vals) * -1

    # `is_integer` has to be called: the bare bound method is always truthy,
    # which would let non-divisible configurations through to `int()` below.
    assert determined_val >= 1 and float(determined_val).is_integer(), (
      'Unspecified value unable to be determined with the given '
      f' {parallelism_type} parallelism values'
    )

    parallelism_vals[parallelism_vals.index(-1)] = int(determined_val)

  target_type = 'slices' if parallelism_type == 'DCN' else 'devices per slice'

  assert np.prod(parallelism_vals) == target_product, (
    f'Number of {target_type} {target_product} does not match the product'
    f' of the {parallelism_type} parallelism {np.prod(parallelism_vals)}'
  )

  return parallelism_vals


# State initialization utils.
# -----------------------------------------------------------------------------


def _to_array(x):
  """Returns `x` unchanged if it is a `jax.Array`, else `jnp.asarray(x)`."""
  if not isinstance(x, jax.Array):
    x = jnp.asarray(x)
  return x


def setup_initial_state(
  constructor: Callable[[TransformerConfig, jax.Array], Transformer],
  tx,
  config: TransformerConfig,
  rng: jax.Array,
  mesh: jax.sharding.Mesh,
) -> tuple[TrainState, TrainState]:
  """Initializes the model and optimizer state under the given mesh.

  The model is constructed inside a jitted function, split into its graph
  definition and parameters, wrapped in a `TrainState`, and constrained to the
  partition spec returned by `nnx.get_partition_spec`.

  Args:
    constructor: the model constructor, called as `constructor(config, rng)`
    tx: the optax.GradientTransformation
    config: config object
    rng: jax.prng key
    mesh: jax.devices() mesh

  Returns:
    state: the initialized train state
    state_sharding: the result of `nnx.get_named_sharding(state, mesh)`
  """

  @jax.jit
  def sharded_init():
    model = constructor(config, rng)
    graphdef, params = nnx.split(model, nnx.Param)
    state = TrainState.create(
      apply_fn=graphdef.apply,
      params=params,
      tx=tx,
      graphdef=graphdef,
    )
    # Convert every non-`jax.Array` leaf to an array before constraining.
    state = jax.tree.map(_to_array, state)
    state_spec = nnx.get_partition_spec(state)
    state = jax.lax.with_sharding_constraint(state, state_spec)
    return state

  # Initialization
  with jax.sharding.use_mesh(mesh):
    state = sharded_init()

  state_sharding = nnx.get_named_sharding(state, mesh)
  return state, state_sharding