Merged (changes from 4 commits)
4 changes: 2 additions & 2 deletions .github/workflows/unit_tests.yaml
@@ -100,7 +100,7 @@ jobs:
           source install_everything.sh
       - name: Run interactive (bf16)
         run: |
-          JAX_PLATFORMS=cpu python run_interactive.py --size=tiny --batch_size=1 --max_cache_length=2048 --tokenizer_path=jetstream_pt/third_party/llama/tokenizer.model --model_name=llama-2 --sharding_config=default_shardings/llama.yaml --quantize_weights=0 --quantize_kv_cache=0
+          JAX_PLATFORMS=cpu python run_interactive.py --size=tiny --batch_size=1 --max_cache_length=2048 --tokenizer_path=jetstream_pt/third_party/llama/tokenizer.model --model_name=llama-2 --sharding_config=default_shardings/llama-2.yaml --quantize_weights=0 --quantize_kv_cache=0
       - name: Run interactive (int8)
         run: |
-          JAX_PLATFORMS=cpu python run_interactive.py --size=tiny --batch_size=1 --max_cache_length=2048 --tokenizer_path=jetstream_pt/third_party/llama/tokenizer.model --model_name=llama-2 --sharding_config=default_shardings/llama.yaml --quantize_weights=1 --quantize_type="int8_per_channel" --quantize_kv_cache=1
+          JAX_PLATFORMS=cpu python run_interactive.py --size=tiny --batch_size=1 --max_cache_length=2048 --tokenizer_path=jetstream_pt/third_party/llama/tokenizer.model --model_name=llama-2 --sharding_config=default_shardings/llama-2.yaml --quantize_weights=1 --quantize_type="int8_per_channel" --quantize_kv_cache=1
11 changes: 8 additions & 3 deletions README.md
@@ -84,17 +84,22 @@ export tokenizer_path=tokenizer model file path
 
 ## Llama-2 7b
 ```bash
-python run_interactive.py --size=7b --model_name=$model_name --batch_size=128 --max_cache_length=2048 --quantize_weights=$quantize --quantize_type=$quantize_type --quantize_kv_cache=$quantize --checkpoint_path=$output_ckpt_dir --tokenizer_path=$tokenizer_path --sharding_config=default_shardings/llama.yaml
+python run_interactive.py --size=7b --model_name=$model_name --batch_size=128 --max_cache_length=2048 --quantize_weights=$quantize --quantize_type=$quantize_type --quantize_kv_cache=$quantize --checkpoint_path=$output_ckpt_dir --tokenizer_path=$tokenizer_path --sharding_config=default_shardings/$model_name.yaml
 ```
 
 ## Llama-2 13b
 ```bash
-python run_interactive.py --size=13b --model_name=$model_name --batch_size=64 --max_cache_length=2048 --quantize_weights=$quantize --quantize_type=$quantize_type --quantize_kv_cache=$quantize --checkpoint_path=$output_ckpt_dir --tokenizer_path=$tokenizer_path --sharding_config=default_shardings/llama.yaml
+python run_interactive.py --size=13b --model_name=$model_name --batch_size=64 --max_cache_length=2048 --quantize_weights=$quantize --quantize_type=$quantize_type --quantize_kv_cache=$quantize --checkpoint_path=$output_ckpt_dir --tokenizer_path=$tokenizer_path --sharding_config=default_shardings/$model_name.yaml
 ```
 
 ## Llama-3 8b
 ```bash
-python run_interactive.py --size=8b --model_name=$model_name --batch_size=128 --max_cache_length=2048 --quantize_weights=$quantize --quantize_type=$quantize_type --quantize_kv_cache=$quantize --checkpoint_path=$output_ckpt_dir --tokenizer_path=$tokenizer_path --sharding_config=default_shardings/llama.yaml
+python run_interactive.py --size=8b --model_name=$model_name --batch_size=128 --max_cache_length=2048 --quantize_weights=$quantize --quantize_type=$quantize_type --quantize_kv_cache=$quantize --checkpoint_path=$output_ckpt_dir --tokenizer_path=$tokenizer_path --sharding_config=default_shardings/$model_name.yaml
 ```
 
+## Llama-3 70b
+```bash
+python run_interactive.py --size=70b --model_name=$model_name --batch_size=8 --max_cache_length=2048 --quantize_weights=$quantize --quantize_type=$quantize_type --quantize_kv_cache=$quantize --checkpoint_path=$output_ckpt_dir --tokenizer_path=$tokenizer_path --sharding_config=default_shardings/$model_name.yaml
+```
+
 ## Gemma 7b
8 changes: 5 additions & 3 deletions convert_checkpoints.py
@@ -179,17 +179,19 @@ def _merge_llama_weights(
         f"{len(tensors)} shards (shape = {tensors[0].shape}) for {key})"
     )
     state_dict_for_key = {}
-    for pattern, kind in llama_model.get_weight_sharding_type.items():
+    for pattern, kind in llama_model.Transformer.get_weight_sharding_type(
+        model_name=FLAGS.model_name
+    ).items():
       if not key.endswith(pattern):
         continue
       with torch.no_grad():
         if kind in ("ParallelEmbedding", "RowParallelLinear"):
           state_dict_for_key[key] = torch.cat(tensors, 1)
-        elif kind == "ColumnParallelLinear":
+        elif kind in ("ColumnParallelLinear", "VocabParallelEmbedding"):
           state_dict_for_key[key] = torch.cat(tensors, 0)
         else:
           if not all(
-              torch.allclose(tensors[0], tensor, atol=1e-6)
+              torch.allclose(tensors[0], tensor, atol=1e-2)
Review thread on this change:

Collaborator: Any reason to loosen the condition by four orders of magnitude?

Collaborator (author): The layer norm weights in llama-3 are not consistent across shards. I don't know why this is the case. These weights are expected to be replicated. It errors out if we don't reduce the precision here.

Collaborator: @qihqi are you ok with the 1e-2 gap? I feel it's risky to loosen the condition by four orders of magnitude for a single tensor.

Collaborator: Yeah, that is fine.
               for tensor in tensors[1:]
           ):
             raise ValueError(
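For readers following the thread above, here is a minimal, self-contained sketch of the replicated-weight path being discussed; it is not the repo's exact code (`merge_replicated` and the sample tensors are hypothetical), but it shows why shards that should be identical now need the relaxed `atol=1e-2` comparison.

```python
import torch


def merge_replicated(key, tensors, atol=1e-2):
  """Verify shards of a replicated tensor agree within `atol`, then keep one copy."""
  if not all(torch.allclose(tensors[0], t, atol=atol) for t in tensors[1:]):
    raise ValueError(f"Replicated tensor {key} differs across shards (atol={atol})")
  return tensors[0]


# A norm weight whose shards drift by ~1e-3 passes at atol=1e-2
# but would have failed the previous atol=1e-6 check.
shard0 = torch.ones(4096)
shard1 = shard0 + 1e-3
merged = merge_replicated("layers.0.ffn_norm.weight", [shard0, shard1])
```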
File renamed without changes.
28 changes: 28 additions & 0 deletions default_shardings/llama-3.yaml
@@ -0,0 +1,28 @@

# Sharding config for llama-3
# Sharding should either be an int between 0 and rank - 1
# signifying the axis to shard or -1 / null signifying replicated


freqs_cis : -1 # torch.complex64 (2048, 64)
tok_embeddings.weight : 0 # torch.float32 (vocab_size, 4096)
Review thread on this file:

Collaborator: The sharding file seems to be the same as llama-2. What's the difference between the llama-2 and llama-3 sharding files? From the change in convert_checkpoints.py, it seems that the llama-3 weights are sharded in a different way, but this sharding file is only used for model sharding during runtime. If that is the case, we don't need another sharding yaml file.

Collaborator (author): The tok_embeddings.weight is sharded differently between llama-2 and llama-3. For llama-2, embeddings are sharded along axis 1, and for llama-3 they are sharded along axis 0. But I agree that it shouldn't make a difference in accuracy during runtime. If you think it is better to keep the same sharding for both of them, I can revert this change.

Collaborator: They shouldn't be sharded differently -- the only difference would be performance; let's run with both and keep the faster one.

tok_embeddings.weight_scaler : 0 # torch.bfloat16 (4096,)
layers.*.attention.wo.weight : 1 # torch.int8 (4096, 4096)
layers.*.attention.wo.weight_scaler : 0 # torch.bfloat16 (4096,)
layers.*.attention.wq.weight : 0 # torch.int8 (4096, 4096)
layers.*.attention.wq.weight_scaler : 0 # torch.bfloat16 (4096,)
layers.*.attention.wk.weight : 0 # torch.int8 (4096, 4096)
layers.*.attention.wk.weight_scaler : 0 # torch.bfloat16 (4096,)
layers.*.attention.wv.weight : 0 # torch.int8 (4096, 4096)
layers.*.attention.wv.weight_scaler : 0 # torch.bfloat16 (4096,)
layers.*.feed_forward.w1.weight : 0 # torch.float32 (11008, 4096)
layers.*.feed_forward.w1.weight_scaler : 0 # torch.bfloat16 (4096,)
layers.*.feed_forward.w2.weight : 1 # torch.float32 (4096, 11008)
layers.*.feed_forward.w2.weight_scaler : 0 # torch.bfloat16 (11008,)
layers.*.feed_forward.w3.weight : 0 # torch.float32 (11008, 4096)
layers.*.feed_forward.w3.weight_scaler : 0 # torch.bfloat16 (4096,)
layers.*.attention_norm.weight : -1 # torch.float32 (4096,)
layers.*.ffn_norm.weight : -1 # torch.float32 (4096,)
norm.weight : -1 # torch.float32 (4096,)
output.weight : 0 # torch.float32 (vocab_size, 4096)
output.weight_scaler : 0 # torch.float32 (4096,)
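To make the review thread earlier in this file concrete, here is a small illustrative sketch of the two layouts for `tok_embeddings.weight` (shapes only; `num_shards` is a hypothetical value and this is not the repo's sharding code): axis 0 splits the vocabulary rows (the VocabParallelEmbedding-style layout used for llama-3 here), while axis 1 splits the hidden dimension (the ParallelEmbedding-style layout used for llama-2).

```python
import torch

vocab_size, dim, num_shards = 128256, 4096, 8
tok_embeddings = torch.randn(vocab_size, dim)

# Axis 0: each shard holds a slice of the vocabulary (llama-3 in this PR).
row_shards = torch.chunk(tok_embeddings, num_shards, dim=0)
print(row_shards[0].shape)  # torch.Size([16032, 4096])

# Axis 1: each shard holds a slice of the hidden dimension (llama-2).
col_shards = torch.chunk(tok_embeddings, num_shards, dim=1)
print(col_shards[0].shape)  # torch.Size([128256, 512])
```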
13 changes: 13 additions & 0 deletions jetstream_pt/third_party/llama/model_args.py
@@ -90,6 +90,19 @@ def get_arg(
         "norm_eps": 1e-05,
         "rope_theta": 500000.0,
     }
+  elif model_name == "llama-3-70b":
+    data = {
+        "dim": 8192,
+        "ffn_dim_multiplier": 1.3,
+        "multiple_of": 4096,
+        "n_heads": 64,
+        "n_kv_heads": 8,
+        "n_layers": 80,
+        "norm_eps": 1e-05,
+        "vocab_size": 128256,
+        "rope_theta": 500000.0,
+    }
+
   return ModelArgs(
       max_seq_len=seqlen,
       max_batch_size=batch_size,
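As a quick sanity check on the new llama-3-70b entry, here is a sketch of the reference Llama feed-forward sizing rule, which maps `dim`, `multiple_of`, and `ffn_dim_multiplier` to the FFN hidden size. This assumes the repo follows the upstream Llama computation; `ffn_hidden_dim` is a hypothetical helper, not part of this file.

```python
def ffn_hidden_dim(dim, multiple_of, ffn_dim_multiplier=None):
  """Reference-Llama style feed-forward hidden size."""
  hidden_dim = int(2 * (4 * dim) / 3)
  if ffn_dim_multiplier is not None:
    hidden_dim = int(ffn_dim_multiplier * hidden_dim)
  # Round up to the nearest multiple of `multiple_of`.
  return multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)


# With the llama-3-70b values above: dim=8192, multiple_of=4096, multiplier=1.3
print(ffn_hidden_dim(8192, 4096, 1.3))  # 28672
```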
15 changes: 12 additions & 3 deletions jetstream_pt/third_party/llama/model_exportable.py
@@ -225,13 +225,17 @@ def get_quantized_embedding_weight_to_scaler_map():
     }
 
   @staticmethod
-  def get_weight_sharding_type():
+  def get_weight_sharding_type(model_name: str = ""):
     # ParallelEmbedding is col partitioned across the shards.
+    # VocabParallelEmbedding is row partitioned across the shards.
     # ColumnParallelLinear is row partitioned across shards due to transpose.
     # RowParallelLinear is col partitioned across shards due to transpose.
     # None is no partitioning and tensor should be identical across shards
-    return {
-        "tok_embeddings.weight": "ParallelEmbedding",
+    expected_model_names = ("llama-2", "llama-3")
+    assert (
+        model_name in expected_model_names
+    ), f"Expected model_name to be one of {expected_model_names}"
+    sharding_dict = {
         "rope.freqs": None,
         "attention.wq.weight": "ColumnParallelLinear",
         "attention.wk.weight": "ColumnParallelLinear",
@@ -245,3 +249,8 @@ def get_weight_sharding_type():
         "norm.weight": None,
         "output.weight": "ColumnParallelLinear",
     }
+    if model_name == "llama-2":
+      sharding_dict["tok_embeddings.weight"] = "ParallelEmbedding"
+    elif model_name == "llama-3":
+      sharding_dict["tok_embeddings.weight"] = "VocabParallelEmbedding"
+    return sharding_dict
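Finally, a brief usage sketch mirroring the convert_checkpoints.py change above: look up the sharding kind per weight and map it to the concatenation axis used when merging checkpoint shards. It assumes the repo is on PYTHONPATH; `merge_axis` is a hypothetical helper that just follows the comments in `get_weight_sharding_type`.

```python
from jetstream_pt.third_party.llama import model_exportable as llama_model


def merge_axis(kind):
  """Concatenation axis for merging checkpoint shards of a given sharding kind."""
  if kind in ("ParallelEmbedding", "RowParallelLinear"):
    return 1
  if kind in ("ColumnParallelLinear", "VocabParallelEmbedding"):
    return 0
  return None  # replicated; shards are expected to (almost) match


sharding = llama_model.Transformer.get_weight_sharding_type(model_name="llama-3")
print(sharding["tok_embeddings.weight"])              # VocabParallelEmbedding
print(merge_axis(sharding["tok_embeddings.weight"]))  # 0
```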