diff --git a/README.md b/README.md
index 58b57d7f..d2bea897 100644
--- a/README.md
+++ b/README.md
@@ -64,12 +64,21 @@ huggingface-cli download google/gemma-7b-pytorch --local-dir $input_ckpt_dir
 
 Need to manually modify the `config.json` in the checkpoint folder to make it a valid JSON file. (Replace `'` with `"`, remove the excessive `,` after the last item in the JSON object)
 
+## Mixtral
+### Get Mixtral Checkpoint from HuggingFace
+
+Please sign the agreement on the HuggingFace website to access the Mixtral checkpoints, then download the Mixtral PyTorch checkpoint with `huggingface-cli`. The Mixtral tokenizer is included in the checkpoint.
+
+```bash
+huggingface-cli download mistralai/Mixtral-8x7B-v0.1 --local-dir $input_ckpt_dir
+```
+
 ## Run weight safetensor convert
 
 ```bash
 export input_ckpt_dir=Original llama weights directory
 export output_ckpt_dir=The output directory
-export model_name="llama-3" # or "llama-2", "gemma"
+export model_name="llama-3" # or "llama-2", "gemma", "mixtral"
 export quantize_weights=True # Whether to quantize weights
 export quantize_type="int8_per_channel" # "quantize_weights" needs to be turned on. Availabe quantize type: {"int8", "int4"} x {"per_channel", "blockwise"}, "int8_per_channel" is the default option if not specified.
 python -m convert_checkpoints --model_name=$model_name --input_checkpoint_dir=$input_ckpt_dir --output_checkpoint_dir=$output_ckpt_dir --quantize_type=$quantize_type
@@ -108,6 +117,11 @@ python run_interactive.py --size=70b --model_name=$model_name --batch_size=8 --m
 python run_interactive.py --model_name=$model_name --size=7b --batch_size=64 --max_cache_length=2048 --quantize_weights=$quantize --quantize_type=$quantize_type --quantize_kv_cache=$quantize --checkpoint_path=$output_ckpt_dir --tokenizer_path=$tokenizer_path --sharding_config=default_shardings/$model_name.yaml
 ```
 
+## Mixtral 8x7b
+```bash
+python run_interactive.py --model_name=$model_name --batch_size=128 --max_cache_length=2048 --quantize_weights=$quantize --quantize_type=$quantize_type --quantize_kv_cache=$quantize --checkpoint_path=$output_ckpt_dir --tokenizer_path=$tokenizer_path --sharding_config=default_shardings/$model_name.yaml
+```
+
 # Run the server
 Here is an example to run the server with llama2 7B config.
 
diff --git a/convert_checkpoints.py b/convert_checkpoints.py
index 1b3af726..52a4cc82 100644
--- a/convert_checkpoints.py
+++ b/convert_checkpoints.py
@@ -26,6 +26,7 @@ import hashlib
 import json
 import os
+import re
 import time
 
 import torch
@@ -37,6 +38,8 @@ from jetstream_pt.config import FLAGS
 from jetstream_pt.third_party.gemma import model as gemma_model
 from jetstream_pt.third_party.llama import model_exportable as llama_model
 
+from jetstream_pt.third_party.mixtral import model as mixtral_model
+
 from safetensors import safe_open
 from safetensors.torch import save_file
@@ -123,6 +126,12 @@ def _quantize_state_dict(
         block_size = orig_block_size
         n_bit = orig_n_bit
   state_dict.update(updated_weights)
+  for k, v in state_dict.items():
+    if "layers" in k and "layers.0" not in k:
+      continue
+    print(
+        f"After quantization the converted key: {k} and value: {v.shape} {v.dtype}"
+    )
   return state_dict
 
 
@@ -462,6 +471,89 @@ def _get_gemma_state_dict(input_ckpt_dir):
   return state_dict, model_config
 
 
+def _get_mixtral_state_dict(input_ckpt_dir):
+  ckpt_files = list(input_ckpt_dir.glob("*.pt"))
+  assert len(ckpt_files) == 8, "Expected 8 *.pt files for the Mixtral model."
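+  # The shards are merged into a single state_dict below; keys are then
+  # renamed and the expert weights reshaped to the layout expected by
+  # jetstream_pt.third_party.mixtral.model.Transformer (see weight_map).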
+ + start = time.perf_counter() + state_dict = {} + for file in sorted(ckpt_files): + ckpt = torch.load( + str(file), map_location="cpu", mmap=True, weights_only=True + ) + state_dict.update(ckpt) + end = time.perf_counter() + print(f"Loading checkpoints takes {end - start} seconds") + + for k, v in state_dict.items(): + if "layers" in k and "layers.0" not in k: + continue + print(f"The loaded key: {k} and value: {v.shape} {v.dtype}") + + config = json.loads((input_ckpt_dir / "config.json").read_text()) + print(f"Loaded config: {config}") + weight_map = { + "layers.{}.block_sparse_moe.w1": "layers.{}.block_sparse_moe.cond_ffn.w1", + "layers.{}.block_sparse_moe.w2": "layers.{}.block_sparse_moe.cond_ffn.w2", + "layers.{}.block_sparse_moe.w3": "layers.{}.block_sparse_moe.cond_ffn.w3", + } + for key in list(state_dict.keys()): + if state_dict[key].dtype.is_complex and _OUTPUT_SAFETENSORS.value: + assert ( + key == "freqs_cis" + ), "Only expect key 'freqs_cis' in the state_dict has complex dtype." + # Remove "freqs_cis" since it has complex dtype, and safetensor doesn't support it. + # The "freqs_cis" will be reconstructed when it's loaded by inference engine. + state_dict.pop(key) + continue + prefix_to_remove = "model." + new_key = key + if key.startswith(prefix_to_remove): + new_key = new_key.removeprefix(prefix_to_remove) + + if "layers" in key: + abstract_key = re.sub(r".(\d+).", ".{}.", key) + layer_num = re.search(r"\d+", key).group(0) + new_key = weight_map.get(abstract_key) + if new_key is None: + continue + new_key = new_key.format(layer_num) + + if new_key == key: + continue + + if "w1" in key or "w3" in key: + state_dict[new_key] = ( + state_dict.pop(key) + .reshape( + config["num_local_experts"], + config["intermediate_size"], + config["hidden_size"], + ) + .contiguous() + ) + elif "w2" in key: + state_dict[new_key] = ( + state_dict.pop(key) + .reshape( + config["num_local_experts"], + config["intermediate_size"], + config["hidden_size"], + ) + .permute(0, 2, 1) + .contiguous() + ) + elif "gate" in key: + state_dict[new_key] = state_dict.pop(key).contiguous() + else: + state_dict[new_key] = state_dict.pop(key) + for k, v in state_dict.items(): + if "layers" in k and "layers.0" not in k: + continue + print(f"The converted key: {k} and value: {v.shape} {v.dtype}") + return state_dict, config + + def main(argv) -> None: """merge weights""" @@ -473,6 +565,14 @@ def main(argv) -> None: quantize_embedding_weight_map = ( gemma_model.GemmaModel.get_quantized_embedding_weight_to_scaler_map() ) + elif FLAGS.model_name == "mixtral": + state_dict, params = _get_mixtral_state_dict(_INPUT_CHECKPOINT_DIR.value) + quantize_linear_weight_map = ( + mixtral_model.Transformer.get_quantized_linear_weight_to_scaler_map() + ) + quantize_embedding_weight_map = ( + mixtral_model.Transformer.get_quantized_embedding_weight_to_scaler_map() + ) else: state_dict, params = _get_llama_state_dict(_INPUT_CHECKPOINT_DIR.value) quantize_linear_weight_map = ( diff --git a/default_shardings/mixtral.yaml b/default_shardings/mixtral.yaml new file mode 100644 index 00000000..85908d23 --- /dev/null +++ b/default_shardings/mixtral.yaml @@ -0,0 +1,32 @@ + +# Sharding config for mixtral +# Sharding should either be an int between 0 and rank - 1 +# signifying the axis to shard or -1 / null signifying replicated + + +freqs_cis : -1 # torch.complex64 (2048, 64) +tok_embeddings.weight : 1 # torch.float32 (vocab_size, 4096) +tok_embeddings.weight_scaler : 0 # torch.bfloat16 (4096,) +layers.*.attention.wo.weight : 1 # torch.int8 
(4096, 4096) +layers.*.attention.wo.weight_scaler : -1 # torch.bfloat16 (4096,) +layers.*.attention.wq.weight : 0 # torch.int8 (4096, 4096) +layers.*.attention.wq.weight_scaler : 0 # torch.bfloat16 (4096,) +layers.*.attention.wk.weight : 0 # torch.int8 (4096, 4096) +layers.*.attention.wk.weight_scaler : 0 # torch.bfloat16 (4096,) +layers.*.attention.wv.weight : 0 # torch.int8 (4096, 4096) +layers.*.attention.wv.weight_scaler : 0 # torch.bfloat16 (4096,) +layers.*.attention.wqkv.weight : 0 # torch.int8 (4096, 4096) +layers.*.attention.wqkv.weight_scaler : 0 # torch.bfloat16 (4096,) +layers.*.block_sparse_moe.gate.weight: -1 +layers.*.block_sparse_moe.gate.weight_scaler: -1 +layers.*.block_sparse_moe.cond_ffn.w1: 1 +layers.*.block_sparse_moe.cond_ffn.w1_scaler: 1 +layers.*.block_sparse_moe.cond_ffn.w2: 2 +layers.*.block_sparse_moe.cond_ffn.w2_scaler: -1 +layers.*.block_sparse_moe.cond_ffn.w3: 1 +layers.*.block_sparse_moe.cond_ffn.w3_scaler: 1 +layers.*.ffn_norm.weight : -1 # torch.float32 (4096,) +layers.*.attention_norm.weight : -1 # torch.float32 (4096,) +norm.weight : -1 # torch.float32 (4096,) +output.weight : 0 # torch.float32 (vocab_size, 4096) +output.weight_scaler : 0 # torch.float32 (4096,) diff --git a/jetstream_pt/config.py b/jetstream_pt/config.py index bdf5fe41..354ed5d3 100644 --- a/jetstream_pt/config.py +++ b/jetstream_pt/config.py @@ -108,7 +108,13 @@ def create_engine_from_config_flags(): sharding_file_name = FLAGS.sharding_config if not sharding_file_name: sharding_file_name = ( - "llama" if FLAGS.model_name.startswith("llama") else "gemma" + "llama" + if FLAGS.model_name.startswith("llama") + else "gemma" + if FLAGS.model_name.startswith("gemma") + else "mixtral" + if FLAGS.model_name.startswith("mixtral") + else None ) if ( quant_config.enable_weight_quantization diff --git a/jetstream_pt/engine.py b/jetstream_pt/engine.py index cfa5d34f..68402722 100644 --- a/jetstream_pt/engine.py +++ b/jetstream_pt/engine.py @@ -37,6 +37,7 @@ from jetstream_pt.environment import JetEngineEnvironment, JetEngineEnvironmentData, QuantizationConfig from jetstream_pt.third_party.llama import model_exportable as llama_model, model_args from jetstream_pt.third_party.gemma import config as gemma_config, model as gemma_model +from jetstream_pt.third_party.mixtral import config as mixtral_config, model as mixtral_model Mesh = jax.sharding.Mesh @@ -359,7 +360,6 @@ def _insert_wrap( start_insert = decode_state.current_position - prefix.seq_len tokens = decode_state.tokens.at[slot].set(prefix.token) - start_insert = start_insert % self.env.cache_sequence_length # pos < 0 update_indexes = ( @@ -641,12 +641,17 @@ def _load_from_safetensors(self, path): def _load_from_state_dict(self, path): state_dict = torch.load(path, map_location=torch.device("cpu")) weights = {} + print(f"Loaded keys are : {state_dict.keys()}") for key, model_weights in self.pt_model.state_dict().items(): + if key == "freqs_cis": + continue assert key in state_dict, f"key: {key} not found" - weights[key] = torchjax.from_torch(state_dict[key]) + weights[key] = torch_xla2.tensor.t2j(state_dict[key]) assert tuple(model_weights.shape) == tuple( weights[key].shape ), f"key: {key} error: {model_weights.shape} != {weights[key].shape}" + + weights["freqs_cis"] = torch_xla2.tensor.t2j(self.pt_model.freqs_cis) return weights # pylint: disable-next=all @@ -760,7 +765,7 @@ def create_pytorch_engine( ) -> PyTorchEngine: """Returns: The pytorch engine.""" - supported_models = ["llama-2", "llama-3", "gemma"] + supported_models = ["llama-2", 
"llama-3", "gemma", "mixtral"] if model_name not in supported_models: raise NotImplementedError( f"Model name should be one of{','.join(supported_models)}" @@ -772,7 +777,6 @@ def create_pytorch_engine( jax.config.update("jax_traceback_filtering", "off") torch_dtype = torch.bfloat16 if bf16_enable else torch.float32 torch.set_default_dtype(torch_dtype) - checkpoint_format = "" checkpoint_path = "" @@ -797,8 +801,14 @@ def create_pytorch_engine( pt_model = None + sharding_file_name = "" if not sharding_config: - sharding_file_name = "llama" if model_name.startswith("llama") else "gemma" + if model_name.startswith("llama"): + sharding_file_name = "llama" + elif model_name.startswith("gemma"): + sharding_file_name = "gemma" + elif model_name.startswith("mixtral"): + sharding_file_name = "mixtral" sharding_config = os.path.join( "default_shardings", sharding_file_name + ".yaml" ) @@ -851,6 +861,18 @@ def create_pytorch_engine( env = JetEngineEnvironment(env_data) print(f"Enviroment variables: {vars(env)}") pt_model = gemma_model.GemmaModel(args, env) + elif model_name == "mixtral": + args = mixtral_config.ModelArgs.from_name("Mixtral-8x7B-v0.1") + args.device = "meta" + env_data.cache_shape = ( + batch_size, + args.n_local_heads, + max_cache_length, + args.dim // args.n_head, + ) + env_data.num_layers = args.n_layer + env = JetEngineEnvironment(env_data) + pt_model = mixtral_model.Transformer(args, env) else: raise RuntimeError(f"Model with name {model_name} not found") diff --git a/jetstream_pt/third_party/llama/model_exportable.py b/jetstream_pt/third_party/llama/model_exportable.py index 47d4a697..124df690 100644 --- a/jetstream_pt/third_party/llama/model_exportable.py +++ b/jetstream_pt/third_party/llama/model_exportable.py @@ -200,10 +200,10 @@ def forward( ): """ tokens: the input token for decoding + input_pos: the decoding position relative to the start, which is the length of the decoding results caches: kv caches mask: causal mask to filter the attention results start: the starting position for each slot - input_pos: the decoding position relative to the start, which is the length of the decoding results ragged_batch_index: precomputed batch index for ragged attention ragged_block_index: precomputed block index for ragged attention """ diff --git a/jetstream_pt/third_party/mixtral/__init__.py b/jetstream_pt/third_party/mixtral/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/jetstream_pt/third_party/mixtral/config.py b/jetstream_pt/third_party/mixtral/config.py new file mode 100644 index 00000000..cf6ab3d1 --- /dev/null +++ b/jetstream_pt/third_party/mixtral/config.py @@ -0,0 +1,78 @@ +# pylint: disable-all +# # Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +# Mixtral model config +import dataclasses +from dataclasses import dataclass + + +def find_multiple(n: int, k: int) -> int: + if n % k == 0: + return n + return n + k - (n % k) + + +@dataclass +class ModelArgs: + block_size: int = 2048 + vocab_size: int = 32000 + n_layer: int = 32 + n_head: int = 32 + dim: int = 4096 + intermediate_size: int = None + n_local_heads: int = -1 + head_dim: int = 64 + rope_base: float = 10000 + norm_eps: float = 1e-5 + num_experts: int = 8 + num_activated_experts: int = 2 + device: str = "meta" + + def __post_init__(self): + if self.n_local_heads == -1: + self.n_local_heads = self.n_head + if self.intermediate_size is None: + hidden_dim = 4 * self.dim + n_hidden = int(2 * hidden_dim / 3) + self.intermediate_size = find_multiple(n_hidden, 256) + self.head_dim = self.dim // self.n_head + + @classmethod + def from_name(cls, name: str): + if name in transformer_configs: + return cls(**transformer_configs[name]) + # fuzzy search + config = [ + config + for config in transformer_configs + if config in str(name).upper() or config in str(name) + ] + assert len(config) == 1, name + return cls(**transformer_configs[config[0]]) + + +transformer_configs = { + "Mixtral-8x7B-v0.1": dict( + block_size=32768, + n_layer=32, + n_head=32, + n_local_heads=8, + dim=4096, + intermediate_size=14336, + rope_base=1000000.0, + num_experts=8, + num_activated_experts=2, + ), +} diff --git a/jetstream_pt/third_party/mixtral/model.py b/jetstream_pt/third_party/mixtral/model.py new file mode 100644 index 00000000..b0d8d573 --- /dev/null +++ b/jetstream_pt/third_party/mixtral/model.py @@ -0,0 +1,377 @@ +# pylint: disable-all +# # Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
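+# Mixtral 8x7B sparse mixture-of-experts model adapted for the JetStream
+# PyTorch engine. Attention and the (optionally quantized) linear/embedding
+# layers come from jetstream_pt.layers; MOEFeedForward routes each token to
+# num_activated_experts of the num_experts expert FFNs, implemented by
+# ConditionalFeedForward (or Int8ConditionalFeedForward when weight
+# quantization is enabled).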
+ +from dataclasses import dataclass +from typing import Optional, List, Any + +import torch +import torch.nn as nn +from torch import Tensor +from torch.nn import functional as F +from .config import ModelArgs, find_multiple +from jetstream_pt.layers import Attention, get_quantized_linear_layer, get_quantized_enbedding_layer + +import jax + + +class Transformer(nn.Module): + + def __init__(self, config: ModelArgs, env) -> None: + super().__init__() + self.config = config + self.env = env + + Embedding = get_quantized_enbedding_layer(env.quant_config) + self.tok_embeddings = Embedding( + config.vocab_size, config.dim, device=config.device + ) + self.layers = nn.ModuleList( + TransformerBlock(config, env) for _ in range(config.n_layer) + ) + self.norm = RMSNorm(config.dim, eps=config.norm_eps) + LinearLayer = get_quantized_linear_layer(env.quant_config) + self.output = LinearLayer( + config.dim, config.vocab_size, bias=False, device=config.device + ) + + self.max_batch_size = -1 + self.max_seq_length = -1 + + # TODO(Consider refactor with other models) + freqs_cis = precompute_freqs_cis( + self.config.block_size, + self.config.dim // self.config.n_head, + self.config.rope_base, + ) + self.register_buffer("freqs_cis", freqs_cis) + + @torch.no_grad() + def forward( + self, + idx: Tensor, + input_pos: Optional[Tensor], + caches: List[Any], + mask, + start: Optional[Tensor] = None, + ragged_batch_index=None, + ragged_block_index=None, + ) -> Tensor: + assert self.freqs_cis is not None, "Caches must be initialized first" + end = None if start is None else (start + input_pos) % self.env.cache_len + with jax.named_scope("transformer_tok"): + x = self.tok_embeddings(idx) + with jax.named_scope("transformer_freq"): + bsz, seqlen = idx.shape + freqs_cis = self.freqs_cis[input_pos] + freqs_cis = freqs_cis.reshape(bsz, seqlen, -1) + assert len(caches) == len( + self.layers + ), f"Number of caches ({len(caches)}) and layers ({len(self.layers)}) dont match" + for layer, cache in zip(self.layers, caches): + with jax.named_scope("TransformerBlock"): + x = layer( + x, + freqs_cis, + mask, + cache, + start, + end, + ragged_batch_index, + ragged_block_index, + ) + + with jax.named_scope("transformer_norm"): + x = self.norm(x) + logits = self.output(x) + return logits + + @staticmethod + def get_quantized_linear_weight_to_scaler_map(): + return { + "attention.wq.weight": "attention.wq.weight_scaler", + "attention.wk.weight": "attention.wk.weight_scaler", + "attention.wv.weight": "attention.wv.weight_scaler", + "attention.wo.weight": "attention.wo.weight_scaler", + "output.weight": "output.weight_scaler", + "block_sparse_moe.gate.weight": "block_sparse_moe.gate.weight_scaler", + "block_sparse_moe.cond_ffn.w1": "block_sparse_moe.cond_ffn.w1_scaler", + "block_sparse_moe.cond_ffn.w2": "block_sparse_moe.cond_ffn.w2_scaler", + "block_sparse_moe.cond_ffn.w3": "block_sparse_moe.cond_ffn.w3_scaler", + } + + @staticmethod + def get_quantized_embedding_weight_to_scaler_map(): + return { + "tok_embeddings.weight": "tok_embeddings.weight_scaler", + } + + @staticmethod + def get_weight_sharding_type(): + # ParallelEmbedding is col partitioned across the shards. + # ColumnParallelLinear is row partitioned across shards due to transpose. + # RowParallelLinear is col partitioned across shards due to transpose. 
+ # None is no partitioning and tensor should be identical across shards + return { + "tok_embeddings.weight": "ParallelEmbedding", + "rope.freqs": None, + "attention.wq.weight": "ColumnParallelLinear", + "attention.wk.weight": "ColumnParallelLinear", + "attention.wv.weight": "ColumnParallelLinear", + "attention.wo.weight": "RowParallelLinear", + "feed_forward.w1.weight": "ColumnParallelLinear", + "feed_forward.w2.weight": "RowParallelLinear", + "feed_forward.w3.weight": "ColumnParallelLinear", + "attention_norm.weight": None, + "ffn_norm.weight": None, + "norm.weight": None, + "output.weight": "ColumnParallelLinear", + } + + +class TransformerBlock(nn.Module): + + def __init__(self, config: ModelArgs, env) -> None: + super().__init__() + self.attention = Attention( + config.n_head, + config.n_local_heads, + config.head_dim, + config.dim, + env=env, + device=config.device, + ) + self.block_sparse_moe = MOEFeedForward(config, config.device, env) + self.ffn_norm = RMSNorm(config.dim, config.norm_eps) + self.attention_norm = RMSNorm(config.dim, config.norm_eps) + + def forward( + self, + x: Tensor, + freqs_cis: Tensor, + mask: Tensor, + caches: List[Tensor], + start=None, + end=None, + ragged_batch_index=None, + ragged_block_index=None, + ) -> Tensor: + with jax.named_scope("Attention"): + attn = self.attention( + self.attention_norm(x), + freqs_cis, + mask, + caches, + start, + end, + ragged_batch_index, + ragged_block_index, + ) + with jax.named_scope("ffn_norm"): + h = x + attn + ffns = self.ffn_norm(h) + with jax.named_scope("ffn"): + moe = self.block_sparse_moe(ffns) + out = h + moe + return out + + +class Int8ConditionalFeedForward(nn.Module): + + def __init__(self, config): + super().__init__() + w1 = torch.empty( + config.num_experts, + config.intermediate_size, + config.dim, + dtype=torch.int8, + ) + w2 = torch.empty( + config.num_experts, + config.dim, + config.intermediate_size, + dtype=torch.int8, + ) + w3 = torch.empty( + config.num_experts, + config.intermediate_size, + config.dim, + dtype=torch.int8, + ) + self.register_buffer("w1", w1) + self.register_buffer("w2", w2) + self.register_buffer("w3", w3) + + w1_scaler = torch.empty(config.num_experts, config.intermediate_size) + w2_scaler = torch.empty(config.num_experts, config.dim) + w3_scaler = torch.empty(config.num_experts, config.intermediate_size) + self.register_buffer("w1_scaler", w1_scaler) + self.register_buffer("w2_scaler", w2_scaler) + self.register_buffer("w3_scaler", w3_scaler) + + def forward(self, x: Tensor, expert_indices: Tensor) -> Tensor: + seq_len = x.shape[0] + if seq_len >= 4: + return self.forward_for_long_seq_len(x, expert_indices) + else: + return self.forward_for_short_seq_len(x, expert_indices) + + def forward_for_short_seq_len( + self, x: Tensor, expert_indices: Tensor + ) -> Tensor: + with jax.named_scope("conditional_ff"): + w1_weights = self.w1[expert_indices] # [T, A, D, D] + w3_weights = self.w3[expert_indices] # [T, A, D, D] + w2_weights = self.w2[expert_indices] # [T, A, D, D] + w1_scaler = self.w1_scaler[expert_indices] + w2_scaler = self.w2_scaler[expert_indices] + w3_scaler = self.w3_scaler[expert_indices] + + x1 = F.silu(torch.einsum("ti,taoi -> tao", x, w1_weights) * w1_scaler) + x3 = torch.einsum("ti, taoi -> tao", x, w3_weights) * w3_scaler + expert_outs = ( + torch.einsum("tao, taio -> tai", (x1 * x3), w2_weights) * w2_scaler + ) + return expert_outs + + def forward_for_long_seq_len(self, x, expert_indices): + seqlen = x.shape[0] + num_experts = self.w1.shape[0] + + # e = total num of 
exp = 8 + # t = seqlen + # o = config.imtermediate size + # i = config.dim + with jax.named_scope("conditional_ff"): + x1 = F.silu(torch.einsum("ti,eoi -> teo", x, self.w1) * self.w1_scaler) + x3 = torch.einsum("ti, eoi-> teo", x, self.w3) * self.w3_scaler + expert_outs = ( + torch.einsum("teo, eio -> tei", (x1 * x3), self.w2) * self.w2_scaler + ) + # e = 8; need to reduce to 2 + seq_indexes = torch.arange(seqlen).unsqueeze(1) + return expert_outs[seq_indexes, expert_indices] + + +class ConditionalFeedForward(nn.Module): + + def __init__(self, config): + super().__init__() + # TODO(How to enable quantization?) + self.w1 = nn.Parameter( + torch.empty(config.num_experts, config.intermediate_size, config.dim) + ) + self.w2 = nn.Parameter( + torch.empty(config.num_experts, config.dim, config.intermediate_size) + ) + self.w3 = nn.Parameter( + torch.empty(config.num_experts, config.intermediate_size, config.dim) + ) + + def forward(self, x: Tensor, expert_indices: Tensor) -> Tensor: + seq_len = x.shape[0] + if seq_len >= 4: + return self.forward_for_long_seq_len(x, expert_indices) + else: + return self.forward_for_short_seq_len(x, expert_indices) + + def forward_for_short_seq_len( + self, x: Tensor, expert_indices: Tensor + ) -> Tensor: + with jax.named_scope("conditional_ff"): + w1_weights = self.w1[expert_indices] # [T, A, D, D] + w3_weights = self.w3[expert_indices] # [T, A, D, D] + w2_weights = self.w2[expert_indices] # [T, A, D, D] + + x1 = F.silu(torch.einsum("ti,taoi -> tao", x, w1_weights)) + x3 = torch.einsum("ti, taoi -> tao", x, w3_weights) + expert_outs = torch.einsum("tao, taio -> tai", (x1 * x3), w2_weights) + return expert_outs + + def forward_for_long_seq_len(self, x, expert_indices): + seqlen = x.shape[0] + num_experts = self.w1.shape[0] + + # e = total num of exp = 8 + # t = seqlen + # o = config.imtermediate size + # i = config.dim + with jax.named_scope("conditional_ff"): + x1 = F.silu(torch.einsum("ti,eoi -> teo", x, self.w1)) + x3 = torch.einsum("ti, eoi-> teo", x, self.w3) + expert_outs = torch.einsum("teo, eio -> tei", (x1 * x3), self.w2) + # e = 8; need to reduce to 2 + seq_indexes = torch.arange(seqlen).unsqueeze(1) + return expert_outs[seq_indexes, expert_indices] + + +class MOEFeedForward(nn.Module): + + def __init__(self, config, device, env) -> None: + super().__init__() + LinearLayer = get_quantized_linear_layer(env.quant_config) + self.gate = LinearLayer(config.dim, config.num_experts, bias=False) + CondLayer = ( + Int8ConditionalFeedForward + if env.quant_config.enable_weight_quantization + else ConditionalFeedForward + ) + self.cond_ffn = CondLayer(config) + self.dim = config.dim + self.num_activated_experts = config.num_activated_experts + + def forward(self, x: Tensor) -> Tensor: + bsz, seq, hidden = x.shape + # [B, T, D], combine BT, for prefill B = 1, for decode, T = 1 + x = x.view(-1, self.dim) + # T = num_tokens, E = num_experts, D = hidden dim, A = activated experts + # x: [T, D] + scores = self.gate(x) # [T, E] + expert_weights = F.softmax(scores, dim=-1) + expert_weights, expert_indices = torch.topk( + expert_weights, self.num_activated_experts, dim=-1 + ) # [T, A], [T, A] + expert_weights /= expert_weights.sum(dim=-1, keepdim=True) # [T, A] + expert_outs = self.cond_ffn(x, expert_indices) + expert_outs = torch.einsum("tai,ta -> ti", expert_outs, expert_weights) + # Changes back to [B, T, D] + expert_outs = expert_outs.reshape(bsz, seq, hidden) + return expert_outs + + +class RMSNorm(nn.Module): + + def __init__(self, dim: int, eps: float = 1e-5): + 
super().__init__() + self.eps = eps + self.weight = nn.Parameter(torch.ones(dim)) + + def _norm(self, x): + return x * torch.rsqrt(torch.mean(x * x, dim=-1, keepdim=True) + self.eps) + + def forward(self, x: Tensor) -> Tensor: + output = self._norm(x.float()).type_as(x) + return output * self.weight + + +def precompute_freqs_cis( + seq_len: int, n_elem: int, base: int = 10000 +) -> Tensor: + freqs = 1.0 / ( + base ** (torch.arange(0, n_elem, 2)[: (n_elem // 2)].float() / n_elem) + ) + t = torch.arange(seq_len) + freqs = torch.outer(t, freqs) + freqs_cis = torch.polar(torch.ones_like(freqs), freqs) + return freqs_cis diff --git a/jetstream_pt/third_party/mixtral/model_original.py b/jetstream_pt/third_party/mixtral/model_original.py new file mode 100644 index 00000000..5087d35a --- /dev/null +++ b/jetstream_pt/third_party/mixtral/model_original.py @@ -0,0 +1,281 @@ +# pylint: disable-all +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass +from typing import Optional + +import torch +import torch.nn as nn +from torch import Tensor +from torch.nn import functional as F +from .config import ModelArgs, find_multiple + + +class KVCache(nn.Module): + + def __init__( + self, + max_batch_size, + max_seq_length, + n_heads, + head_dim, + dtype=torch.bfloat16, + ): + super().__init__() + cache_shape = (max_batch_size, n_heads, max_seq_length, head_dim) + self.register_buffer("k_cache", torch.zeros(cache_shape, dtype=dtype)) + self.register_buffer("v_cache", torch.zeros(cache_shape, dtype=dtype)) + + def update(self, input_pos, k_val, v_val): + # input_pos: [S], k_val: [B, H, S, D] + assert input_pos.shape[0] == k_val.shape[2] + + k_out = self.k_cache + v_out = self.v_cache + k_out[:, :, input_pos] = k_val + v_out[:, :, input_pos] = v_val + + return k_out, v_out + + +class Transformer(nn.Module): + + def __init__(self, config: ModelArgs) -> None: + super().__init__() + self.config = config + + self.tok_embeddings = nn.Embedding(config.vocab_size, config.dim) + self.layers = nn.ModuleList( + TransformerBlock(config) for _ in range(config.n_layer) + ) + self.norm = RMSNorm(config.dim, eps=config.norm_eps) + self.output = nn.Linear(config.dim, config.vocab_size, bias=False) + + self.freqs_cis: Optional[Tensor] = None + self.mask_cache: Optional[Tensor] = None + self.max_batch_size = -1 + self.max_seq_length = -1 + + def setup_caches(self, max_batch_size, max_seq_length): + if ( + self.max_seq_length >= max_seq_length + and self.max_batch_size >= max_batch_size + ): + return + head_dim = self.config.dim // self.config.n_head + max_seq_length = find_multiple(max_seq_length, 8) + self.max_seq_length = max_seq_length + self.max_batch_size = max_batch_size + for b in self.layers: + b.attention.kv_cache = KVCache( + max_batch_size, max_seq_length, self.config.n_local_heads, head_dim + ) + + self.freqs_cis = precompute_freqs_cis( + self.config.block_size, + self.config.dim // self.config.n_head, + self.config.rope_base, + ) + self.causal_mask = 
torch.tril( + torch.ones(self.max_seq_length, self.max_seq_length, dtype=torch.bool) + ) + + def forward(self, idx: Tensor, input_pos: Optional[Tensor] = None) -> Tensor: + assert self.freqs_cis is not None, "Caches must be initialized first" + mask = self.causal_mask[None, None, input_pos] + freqs_cis = self.freqs_cis[input_pos] + x = self.tok_embeddings(idx) + + for i, layer in enumerate(self.layers): + x = layer(x, input_pos, freqs_cis, mask) + x = self.norm(x) + logits = self.output(x) + return logits + + @classmethod + def from_name(cls, name: str): + return cls(ModelArgs.from_name(name)) + + +class TransformerBlock(nn.Module): + + def __init__(self, config: ModelArgs) -> None: + super().__init__() + self.attention = Attention(config) + self.block_sparse_moe = MOEFeedForward(config) + self.ffn_norm = RMSNorm(config.dim, config.norm_eps) + self.attention_norm = RMSNorm(config.dim, config.norm_eps) + + def forward( + self, x: Tensor, input_pos: Tensor, freqs_cis: Tensor, mask: Tensor + ) -> Tensor: + h = x + self.attention(self.attention_norm(x), freqs_cis, mask, input_pos) + out = h + self.block_sparse_moe(self.ffn_norm(h)) + return out + + +class Attention(nn.Module): + + def __init__(self, config: ModelArgs): + super().__init__() + assert config.dim % config.n_head == 0 + + total_head_dim = ( + config.n_head + 2 * config.n_local_heads + ) * config.head_dim + # key, query, value projections for all heads, but in a batch + self.wqkv = nn.Linear(config.dim, total_head_dim, bias=False) + self.wo = nn.Linear(config.dim, config.dim, bias=False) + self.kv_cache = None + + self.n_head = config.n_head + self.head_dim = config.head_dim + self.n_local_heads = config.n_local_heads + self.dim = config.dim + self._register_load_state_dict_pre_hook(self.load_hook) + + def load_hook(self, state_dict, prefix, *args): + if prefix + "wq.weight" in state_dict: + wq = state_dict.pop(prefix + "wq.weight") + wk = state_dict.pop(prefix + "wk.weight") + wv = state_dict.pop(prefix + "wv.weight") + state_dict[prefix + "wqkv.weight"] = torch.cat([wq, wk, wv]) + + def forward( + self, + x: Tensor, + freqs_cis: Tensor, + mask: Tensor, + input_pos: Optional[Tensor] = None, + ) -> Tensor: + bsz, seqlen, _ = x.shape + + kv_size = self.n_local_heads * self.head_dim + q, k, v = self.wqkv(x).split([self.dim, kv_size, kv_size], dim=-1) + + q = q.view(bsz, seqlen, self.n_head, self.head_dim) + k = k.view(bsz, seqlen, self.n_local_heads, self.head_dim) + v = v.view(bsz, seqlen, self.n_local_heads, self.head_dim) + + q = apply_rotary_emb(q, freqs_cis) + k = apply_rotary_emb(k, freqs_cis) + + q, k, v = map(lambda x: x.transpose(1, 2), (q, k, v)) + + if self.kv_cache is not None: + k, v = self.kv_cache.update(input_pos, k, v) + + k = k.repeat_interleave(self.n_head // self.n_local_heads, dim=1) + v = v.repeat_interleave(self.n_head // self.n_local_heads, dim=1) + y = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0) + + y = y.transpose(1, 2).contiguous().view(bsz, seqlen, self.dim) + + y = self.wo(y) + return y + + +class ConditionalFeedForward(nn.Module): + + def __init__(self, config): + super().__init__() + self.w1 = nn.Parameter( + torch.empty(config.num_experts, config.intermediate_size, config.dim) + ) + self.w2 = nn.Parameter( + torch.empty(config.num_experts, config.dim, config.intermediate_size) + ) + self.w3 = nn.Parameter( + torch.empty(config.num_experts, config.intermediate_size, config.dim) + ) + + def forward(self, x: Tensor, expert_indices: Tensor) -> Tensor: + # T = num_tokens, I = 
intermediate size, D = hidden dim, A = activated experts + w1_weights = self.w1[expert_indices] # [T, A, D, D] + w3_weights = self.w3[expert_indices] # [T, A, D, D] + w2_weights = self.w2[expert_indices] # [T, A, D, D] + # x: [T, D] + x1 = F.silu(torch.einsum("ti,taoi -> tao", x, w1_weights)) + x3 = torch.einsum("ti, taoi -> tao", x, w3_weights) + expert_outs = torch.einsum("tao, taio -> tai", (x1 * x3), w2_weights) + return expert_outs + + +class MOEFeedForward(nn.Module): + + def __init__(self, config, env=None) -> None: + super().__init__() + self.gate = nn.Linear(config.dim, config.num_experts, bias=False) + self.cond_ffn = ConditionalFeedForward(config) + self.dim = config.dim + self.num_activated_experts = config.num_activated_experts + + def forward(self, x: Tensor) -> Tensor: + x = x.view(-1, self.dim) + # T = num_tokens, E = num_experts, D = hidden dim, A = activated experts + # x: [T, D] + scores = self.gate(x) # [T, E] + expert_weights = F.softmax(scores, dim=-1) + expert_weights, expert_indices = torch.topk( + expert_weights, self.num_activated_experts, dim=-1 + ) # [T, A], [T, A] + expert_weights /= expert_weights.sum(dim=-1, keepdim=True) # [T, A] + expert_outs = self.cond_ffn(x, expert_indices) + return torch.einsum("tai,ta -> ti", expert_outs, expert_weights) + + +class RMSNorm(nn.Module): + + def __init__(self, dim: int, eps: float = 1e-5): + super().__init__() + self.eps = eps + self.weight = nn.Parameter(torch.ones(dim)) + + def _norm(self, x): + return x * torch.rsqrt(torch.mean(x * x, dim=-1, keepdim=True) + self.eps) + + def forward(self, x: Tensor) -> Tensor: + output = self._norm(x.float()).type_as(x) + return output * self.weight + + +def precompute_freqs_cis( + seq_len: int, n_elem: int, base: int = 10000 +) -> Tensor: + freqs = 1.0 / ( + base ** (torch.arange(0, n_elem, 2)[: (n_elem // 2)].float() / n_elem) + ) + t = torch.arange(seq_len, device=freqs.device) + freqs = torch.outer(t, freqs) + freqs_cis = torch.polar(torch.ones_like(freqs), freqs) + cache = torch.stack([freqs_cis.real, freqs_cis.imag], dim=-1) + return cache.to(dtype=torch.bfloat16) + + +def apply_rotary_emb(x: Tensor, freqs_cis: Tensor) -> Tensor: + xshaped = x.float().reshape(*x.shape[:-1], -1, 2) + freqs_cis = freqs_cis.view(1, xshaped.size(1), 1, xshaped.size(3), 2) + x_out2 = torch.stack( + [ + xshaped[..., 0] * freqs_cis[..., 0] + - xshaped[..., 1] * freqs_cis[..., 1], + xshaped[..., 1] * freqs_cis[..., 0] + + xshaped[..., 0] * freqs_cis[..., 1], + ], + -1, + ) + + x_out2 = x_out2.flatten(3) + return x_out2.type_as(x) diff --git a/jetstream_pt/third_party/mixtral/tokenizer.model b/jetstream_pt/third_party/mixtral/tokenizer.model new file mode 100644 index 00000000..85c0803f Binary files /dev/null and b/jetstream_pt/third_party/mixtral/tokenizer.model differ
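
For reference, a minimal end-to-end sketch of the Mixtral path added above (download, weight conversion, interactive run). The directory values are placeholders, and pointing `--tokenizer_path` at the `tokenizer.model` inside the downloaded checkpoint is an assumption based on the README note that the tokenizer ships with the checkpoint.

```bash
# Placeholder paths; adjust for your environment.
export input_ckpt_dir=/tmp/mixtral-8x7b-original
export output_ckpt_dir=/tmp/mixtral-8x7b-converted
export tokenizer_path=$input_ckpt_dir/tokenizer.model
export model_name=mixtral
export quantize=True
export quantize_type=int8_per_channel

# Download the Mixtral checkpoint (requires the signed agreement on HuggingFace).
huggingface-cli download mistralai/Mixtral-8x7B-v0.1 --local-dir $input_ckpt_dir

# Convert the weights (same flags as the convert section in the README above).
python -m convert_checkpoints --model_name=$model_name --input_checkpoint_dir=$input_ckpt_dir --output_checkpoint_dir=$output_ckpt_dir --quantize_type=$quantize_type

# Interactive run with the default Mixtral sharding config.
python run_interactive.py --model_name=$model_name --batch_size=128 --max_cache_length=2048 --quantize_weights=$quantize --quantize_type=$quantize_type --quantize_kv_cache=$quantize --checkpoint_path=$output_ckpt_dir --tokenizer_path=$tokenizer_path --sharding_config=default_shardings/$model_name.yaml
```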